IAM tech spec: Auth and access management current state and proposed changes.

Support for separate workspaces

Addition of workspace CLI support for test purposes
Cyber MacGeddon 2026-04-18 23:07:26 +01:00
parent 48da6c5f8b
commit db05427d0e
219 changed files with 4875 additions and 2616 deletions

.gitignore

@ -15,4 +15,5 @@ trustgraph-parquet/trustgraph/parquet_version.py
trustgraph-vertexai/trustgraph/vertexai_version.py
trustgraph-unstructured/trustgraph/unstructured_version.py
trustgraph-mcp/trustgraph/mcp_version.py
trustgraph/trustgraph/trustgraph_version.py
vertexai/


@ -57,7 +57,7 @@ container-bedrock container-vertexai \
container-hf container-ocr \
container-unstructured container-mcp
-some-containers: container-base container-flow
+some-containers: container-base container-flow container-unstructured
push:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION}


@ -0,0 +1,309 @@
---
layout: default
title: "Data Ownership and Information Separation"
parent: "Tech Specs"
---
# Data Ownership and Information Separation
## Purpose
This document defines the logical ownership model for data in
TrustGraph: what the artefacts are, who owns them, and how they relate
to each other.
The IAM spec ([iam.md](iam.md)) describes authentication and
authorisation mechanics. This spec addresses the prior question: what
are the boundaries around data, and who owns what?
## Concepts
### Workspace
A workspace is the primary isolation boundary. It represents an
organisation, team, or independent operating unit. All data belongs to
exactly one workspace. Cross-workspace access is never permitted through
the API.
A workspace owns:
- Source documents
- Flows (processing pipeline definitions)
- Knowledge cores (stored extraction output)
- Collections (organisational units for extracted knowledge)
### Collection
A collection is an organisational unit within a workspace. It groups
extracted knowledge produced from source documents. A workspace can
have multiple collections, allowing:
- Processing the same documents with different parameters or models.
- Maintaining separate knowledge bases for different purposes.
- Deleting extracted knowledge without deleting source documents.
Collections do not own source documents. A source document exists at the
workspace level and can be processed into multiple collections.
### Source document
A source document (PDF, text file, etc.) is raw input uploaded to the
system. Documents belong to the workspace, not to a specific collection.
This is intentional. A document is an asset that exists independently
of how it is processed. The same PDF might be processed into multiple
collections with different chunking parameters or extraction models.
Tying a document to a single collection would force re-upload for each
collection.
### Flow
A flow defines a processing pipeline: which models to use, what
parameters to apply (chunk size, temperature, etc.), and how processing
services are connected. Flows belong to a workspace.
The processing services themselves (document-decoder, chunker,
embeddings, LLM completion, etc.) are shared infrastructure — they serve
all workspaces. Each flow has its own queues, keeping data from
different workspaces and flows separate as it moves through the
pipeline.
Different workspaces can define different flows. Workspace A might use
GPT-5.2 with a chunk size of 2000, while workspace B uses Claude with a
chunk size of 1000.
### Prompts
Prompts are templates that control how the LLM behaves during knowledge
extraction and query answering. They belong to a workspace, allowing
different workspaces to have different extraction strategies, response
styles, or domain-specific instructions.
### Ontology
An ontology defines the concepts, entities, and relationships that the
extraction pipeline looks for in source documents. Ontologies belong to
a workspace. A medical workspace might define ontologies around diseases,
symptoms, and treatments, while a legal workspace defines ontologies
around statutes, precedents, and obligations.
### Schemas
Schemas define structured data types for extraction. They specify what
fields to extract, their types, and how they relate. Schemas belong to
a workspace, as different workspaces extract different structured
information from their documents.
### Tools, tool services, and MCP tools
Tools define capabilities available to agents: what actions they can
take, what external services they can call. Tool services configure how
tools connect to backend services. MCP tools configure connections to
remote MCP servers, including authentication tokens. All belong to a
workspace.
### Agent patterns and agent task types
Agent patterns define agent behaviour strategies (how an agent reasons,
what steps it follows). Agent task types define the kinds of tasks
agents can perform. Both belong to a workspace, as different workspaces
may have different agent configurations.
### Token costs
Token cost definitions specify pricing for LLM token usage per model.
These belong to a workspace since different workspaces may use different
models or have different billing arrangements.
### Flow blueprints
Flow blueprints are templates for creating flows. They define the
default pipeline structure and parameters. Blueprints belong to a
workspace, allowing workspaces to define custom processing templates.
### Parameter types
Parameter types define the kinds of parameters that flows accept (e.g.
"llm-model", "temperature"), including their defaults and validation
rules. They belong to a workspace since workspaces that define custom
flows need to define the parameter types those flows use.
### Interface descriptions
Interface descriptions define the connection points of a flow — what
queues and topics it uses. They belong to a workspace since they
describe workspace-owned flows.
### Knowledge core
A knowledge core is a stored snapshot of extracted knowledge (triples
and graph embeddings). Knowledge cores belong to a workspace and can be
loaded into any collection within that workspace.
Knowledge cores serve as a portable extraction output. You process
documents through a flow, the pipeline produces triples and embeddings,
and the results can be stored as a knowledge core. That core can later
be loaded into a different collection or reloaded after a collection is
cleared.
### Extracted knowledge
Extracted knowledge is the live, queryable content within a collection:
triples in the knowledge graph, graph embeddings, and document
embeddings. It is the product of processing source documents through a
flow into a specific collection.
Extracted knowledge is scoped to a workspace and a collection. It
cannot exist without both.
### Processing record
A processing record tracks which source document was processed, through
which flow, into which collection. It links the source document
(workspace-scoped) to the extracted knowledge (workspace + collection
scoped).
## Ownership summary
| Artefact | Owned by | Shared across collections? |
|----------|----------|---------------------------|
| Workspaces | Global (platform) | N/A |
| User accounts | Global (platform) | N/A |
| API keys | Global (platform) | N/A |
| Source documents | Workspace | Yes |
| Flows | Workspace | N/A |
| Flow blueprints | Workspace | N/A |
| Prompts | Workspace | N/A |
| Ontologies | Workspace | N/A |
| Schemas | Workspace | N/A |
| Tools | Workspace | N/A |
| Tool services | Workspace | N/A |
| MCP tools | Workspace | N/A |
| Agent patterns | Workspace | N/A |
| Agent task types | Workspace | N/A |
| Token costs | Workspace | N/A |
| Parameter types | Workspace | N/A |
| Interface descriptions | Workspace | N/A |
| Knowledge cores | Workspace | Yes — can be loaded into any collection |
| Collections | Workspace | N/A |
| Extracted knowledge | Workspace + collection | No |
| Processing records | Workspace + collection | No |
## Scoping summary
### Global (system-level)
A small number of artefacts exist outside any workspace:
- **Workspace registry** — the list of workspaces itself
- **User accounts** — users reference a workspace but are not owned by
one
- **API keys** — belong to users, not workspaces
These are managed by the IAM layer and exist at the platform level.
### Workspace-owned
All other configuration and data is workspace-owned:
- Flow definitions and parameters
- Flow blueprints
- Prompts
- Ontologies
- Schemas
- Tools, tool services, and MCP tools
- Agent patterns and agent task types
- Token costs
- Parameter types
- Interface descriptions
- Collection definitions
- Knowledge cores
- Source documents
- Collections and their extracted knowledge
## Relationship between artefacts
```
Platform (global)
|
+-- Workspaces
|     |
+-- User accounts (each assigned to a workspace)
|     |
+-- API keys (belong to users)

Workspace
|
+-- Source documents (uploaded, unprocessed)
|
+-- Flows (pipeline definitions: models, parameters, queues)
|
+-- Flow blueprints (templates for creating flows)
|
+-- Prompts (LLM instruction templates)
|
+-- Ontologies (entity and relationship definitions)
|
+-- Schemas (structured data type definitions)
|
+-- Tools, tool services, MCP tools (agent capabilities)
|
+-- Agent patterns and agent task types (agent behaviour)
|
+-- Token costs (LLM pricing per model)
|
+-- Parameter types (flow parameter definitions)
|
+-- Interface descriptions (flow connection points)
|
+-- Knowledge cores (stored extraction snapshots)
|
+-- Collections
    |
    +-- Extracted knowledge (triples, embeddings)
    |
    +-- Processing records (links documents to collections)
```
A typical workflow:
1. A source document is uploaded to the workspace.
2. A flow defines how to process it (which models, what parameters).
3. The document is processed through the flow into a collection.
4. Processing records track what was processed.
5. Extracted knowledge (triples, embeddings) is queryable within the
collection.
6. Optionally, the extracted knowledge is stored as a knowledge core
for later reuse.
## Implementation notes
The current codebase uses a `user` field in message metadata and storage
partition keys to identify the workspace. The `collection` field
identifies the collection within that workspace. The IAM spec describes
how the gateway maps authenticated credentials to a workspace identity
and sets these fields.
For details on how each storage backend implements this scoping, see:
- [Entity-Centric Graph](entity-centric-graph.md) — Cassandra KG schema
- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md)
- [Collection Management](collection-management.md)
### Known inconsistencies in current implementation
- **Pipeline intermediate tables** do not include collection in their
partition keys. Re-processing the same document into a different
collection may overwrite intermediate state.
- **Processing metadata** stores collection in the row payload but not
in the partition key, making collection-based queries inefficient.
- **Upload sessions** are keyed by upload ID, not workspace. The
gateway should validate workspace ownership before allowing
operations on upload sessions.
## References
- [Identity and Access Management](iam.md)
- [Collection Management](collection-management.md)
- [Entity-Centric Graph](entity-centric-graph.md)
- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md)
- [Multi-Tenant Support](multi-tenant-support.md)


@ -20,8 +20,8 @@ Defines shared service processors that are instantiated once per flow blueprint.
```json
"class": {
"service-name:{class}": {
-"request": "queue-pattern:{class}",
-"response": "queue-pattern:{class}",
+"request": "queue-pattern:{workspace}:{class}",
+"response": "queue-pattern:{workspace}:{class}",
"settings": {
"setting-name": "fixed-value",
"parameterized-setting": "{parameter-name}"
@ -31,11 +31,11 @@ Defines shared service processors that are instantiated once per flow blueprint.
```
**Characteristics:**
-- Shared across all flow instances of the same class
+- Shared across all flow instances of the same class within a workspace
 - Typically expensive or stateless services (LLMs, embedding models)
-- Use `{class}` template variable for queue naming
+- Use `{workspace}` and `{class}` template variables for queue naming
 - Settings can be fixed values or parameterized with `{parameter-name}` syntax
-- Examples: `embeddings:{class}`, `text-completion:{class}`, `graph-rag:{class}`
+- Examples: `embeddings:{workspace}:{class}`, `text-completion:{workspace}:{class}`
### 2. Flow Section
Defines flow-specific processors that are instantiated for each individual flow instance. Each flow gets its own isolated set of these processors.
@ -43,8 +43,8 @@ Defines flow-specific processors that are instantiated for each individual flow
```json
"flow": {
"processor-name:{id}": {
-"input": "queue-pattern:{id}",
-"output": "queue-pattern:{id}",
+"input": "queue-pattern:{workspace}:{id}",
+"output": "queue-pattern:{workspace}:{id}",
"settings": {
"setting-name": "fixed-value",
"parameterized-setting": "{parameter-name}"
@ -56,9 +56,9 @@ Defines flow-specific processors that are instantiated for each individual flow
**Characteristics:**
- Unique instance per flow
- Handle flow-specific data and state
-- Use `{id}` template variable for queue naming
+- Use `{workspace}` and `{id}` template variables for queue naming
 - Settings can be fixed values or parameterized with `{parameter-name}` syntax
-- Examples: `chunker:{id}`, `pdf-decoder:{id}`, `kg-extract-relationships:{id}`
+- Examples: `chunker:{workspace}:{id}`, `pdf-decoder:{workspace}:{id}`
### 3. Interfaces Section
Defines the entry points and interaction contracts for the flow. These form the API surface for external systems and internal component communication.
@ -68,8 +68,8 @@ Interfaces can take two forms:
**Fire-and-Forget Pattern** (single queue):
```json
"interfaces": {
-"document-load": "persistent://tg/flow/document-load:{id}",
-"triples-store": "persistent://tg/flow/triples-store:{id}"
+"document-load": "persistent://tg/flow/{workspace}:document-load:{id}",
+"triples-store": "persistent://tg/flow/{workspace}:triples-store:{id}"
}
```
@ -77,8 +77,8 @@ Interfaces can take two forms:
```json
"interfaces": {
"embeddings": {
-"request": "non-persistent://tg/request/embeddings:{class}",
-"response": "non-persistent://tg/response/embeddings:{class}"
+"request": "non-persistent://tg/request/{workspace}:embeddings:{class}",
+"response": "non-persistent://tg/response/{workspace}:embeddings:{class}"
}
}
```
@ -117,6 +117,16 @@ Additional information about the flow blueprint:
### System Variables
#### {workspace}
- Replaced with the workspace identifier
- Isolates queue names between workspaces so that two workspaces
starting the same flow do not share queues
- Must be included in every queue name pattern in every blueprint
template to ensure workspace isolation
- Examples: `ws-acme`, `ws-globex`
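
As a rough illustration of how these system variables might be substituted, the sketch below expands a queue name pattern and rejects patterns that omit `{workspace}`. The helper names `expand_template` and `require_workspace`, and the sample values `flow-1` and `default`, are assumptions made for this example; they are not taken from the TrustGraph codebase.

```python
import re

# Hypothetical helpers, not the actual TrustGraph implementation.
# Matches placeholders such as {workspace}, {class}, {id}, {parameter-name}.
_PLACEHOLDER = re.compile(r"\{([a-z][a-z0-9-]*)\}")


def expand_template(pattern: str, variables: dict[str, str]) -> str:
    """Replace each {name} in pattern with variables[name]."""
    def substitute(match: re.Match) -> str:
        name = match.group(1)
        if name not in variables:
            # Fail loudly rather than emit a queue name with a hole in it.
            raise KeyError(f"no value for template variable {{{name}}}")
        return variables[name]
    return _PLACEHOLDER.sub(substitute, pattern)


def require_workspace(pattern: str) -> None:
    """Reject queue patterns that would be shared between workspaces."""
    if "{workspace}" not in pattern:
        raise ValueError(f"pattern lacks {{workspace}}: {pattern}")


pattern = "persistent://tg/flow/{workspace}:document-load:{id}"
require_workspace(pattern)
queue = expand_template(pattern, {"workspace": "ws-acme", "id": "flow-1"})
```

Two workspaces expanding the same pattern get distinct queue names, which is the isolation property the variable is meant to provide.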
#### {id}
- Replaced with the unique flow instance identifier
- Creates isolated resources for each flow

docs/tech-specs/iam.md

@ -0,0 +1,858 @@
---
layout: default
title: "Identity and Access Management"
parent: "Tech Specs"
---
# Identity and Access Management
## Problem Statement
TrustGraph has no meaningful identity or access management. The system
relies on a single shared gateway token for authentication and an
honour-system `user` query parameter for data isolation. This creates
several problems:
- **No user identity.** There are no user accounts, no login, and no way
to know who is making a request. The `user` field in message metadata
is a caller-supplied string with no validation — any client can claim
to be any user.
- **No access control.** A valid gateway token grants unrestricted access
to every endpoint, every user's data, every collection, and every
administrative operation. There is no way to limit what an
authenticated caller can do.
- **No credential isolation.** All callers share one static token. There
is no per-user credential, no token expiration, and no rotation
mechanism. Revoking access means changing the shared token, which
affects all callers.
- **Data isolation is unenforced.** Storage backends (Cassandra, Neo4j,
Qdrant) filter queries by `user` and `collection`, but the gateway
does not prevent a caller from specifying another user's identity.
Cross-user data access is trivial.
- **No audit trail.** There is no logging of who accessed what. Without
user identity, audit logging is impossible.
These gaps make the system unsuitable for multi-user deployments,
multi-tenant SaaS, or any environment where access needs to be
controlled or audited.
## Current State
### Authentication
The API gateway supports a single shared token configured via the
`GATEWAY_SECRET` environment variable or `--api-token` CLI argument. If
unset, authentication is disabled entirely. When enabled, every HTTP
endpoint requires an `Authorization: Bearer <token>` header. WebSocket
connections pass the token as a query parameter.
Implementation: `trustgraph-flow/trustgraph/gateway/auth.py`
```python
class Authenticator:
    def __init__(self, token=None, allow_all=False):
        self.token = token
        self.allow_all = allow_all

    def permitted(self, token, roles):
        if self.allow_all: return True
        if self.token != token: return False
        return True
```
The `roles` parameter is accepted but never evaluated. All authenticated
requests have identical privileges.
MCP tool configurations support an optional per-tool `auth-token` for
service-to-service authentication with remote MCP servers. These are
static, system-wide tokens — not per-user credentials. See
[mcp-tool-bearer-token.md](mcp-tool-bearer-token.md) for details.
### User identity
The `user` field is passed explicitly by the caller as a query parameter
(e.g. `?user=trustgraph`) or set by CLI tools. It flows through the
system in the core `Metadata` dataclass:
```python
from dataclasses import dataclass

@dataclass
class Metadata:
    id: str = ""
    root: str = ""
    user: str = ""
    collection: str = ""
```
There is no user registration, login, user database, or session
management.
### Data isolation
The `user` + `collection` pair is used at the storage layer to partition
data:
- **Cassandra**: queries filter by `user` and `collection` columns
- **Neo4j**: queries filter by `user` and `collection` properties
- **Qdrant**: vector search filters by `user` and `collection` metadata
| Layer | Isolation mechanism | Enforced by |
|-------|-------------------|-------------|
| Gateway | Single shared token | `Authenticator` class |
| Message metadata | `user` + `collection` fields | Caller (honour system) |
| Cassandra | Column filters on `user`, `collection` | Query layer |
| Neo4j | Property filters on `user`, `collection` | Query layer |
| Qdrant | Metadata filters on `user`, `collection` | Query layer |
| Pub/sub topics | Per-flow topic namespacing | Flow service |
The storage-layer isolation depends on all queries correctly filtering by
`user` and `collection`. There is no gateway-level enforcement preventing
a caller from querying another user's data by passing a different `user`
parameter.
### Configuration and secrets
| Setting | Source | Default | Purpose |
|---------|--------|---------|---------|
| `GATEWAY_SECRET` | Env var | Empty (auth disabled) | Gateway bearer token |
| `--api-token` | CLI arg | None | Gateway bearer token (overrides env) |
| `PULSAR_API_KEY` | Env var | None | Pub/sub broker auth |
| MCP `auth-token` | Config service | None | Per-tool MCP server auth |
No secrets are encrypted at rest. The gateway token and MCP tokens are
stored and transmitted in plaintext (aside from any transport-layer
encryption such as TLS).
### Capabilities that do not exist
- Per-user authentication (JWT, OAuth, SAML, API keys per user)
- User accounts or user management
- Role-based access control (RBAC)
- Attribute-based access control (ABAC)
- Per-user or per-workspace API keys
- Token expiration or rotation
- Session management
- Per-user rate limiting
- Audit logging of user actions
- Permission checks preventing cross-user data access
- Multi-workspace credential isolation
### Key files
| File | Purpose |
|------|---------|
| `trustgraph-flow/trustgraph/gateway/auth.py` | Authenticator class |
| `trustgraph-flow/trustgraph/gateway/service.py` | Gateway init, token config |
| `trustgraph-flow/trustgraph/gateway/endpoint/*.py` | Per-endpoint auth checks |
| `trustgraph-base/trustgraph/schema/core/metadata.py` | `Metadata` dataclass with `user` field |
## Technical Design
### Design principles
- **Auth at the edge.** The gateway is the single enforcement point.
Internal services trust the gateway and do not re-authenticate.
This avoids distributing credential validation across dozens of
microservices.
- **Identity from credentials, not from callers.** The gateway derives
user identity from authentication credentials. Callers can no longer
self-declare their identity via query parameters.
- **Workspace isolation by default.** Every authenticated user belongs to
a workspace. All data operations are scoped to that workspace.
Cross-workspace access is not possible through the API.
- **Extensible API contract.** The API accepts an optional workspace
parameter on every request. This allows the same protocol to support
single-workspace deployments today and multi-workspace extensions in
the future without breaking changes.
- **Simple roles, not fine-grained permissions.** A small number of
predefined roles controls what operations a user can perform. This is
sufficient for the current API surface and avoids the complexity of
per-resource permission management.
### Authentication
The gateway supports two credential types. Both are carried as a Bearer
token in the `Authorization` header for HTTP requests. The gateway
distinguishes them by format.
For WebSocket connections, credentials are not passed in the URL or
headers. Instead, the client authenticates after connecting by sending
an auth message as the first frame:
```
Client: opens WebSocket to /api/v1/socket
Server: accepts connection (unauthenticated state)
Client: sends {"type": "auth", "token": "tg_abc123..."}
Server: validates token
success → {"type": "auth-ok", "workspace": "acme"}
failure → {"type": "auth-failed", "error": "invalid token"}
```
The server rejects all non-auth messages until authentication succeeds.
The socket remains open on auth failure, allowing the client to retry
with a different token without reconnecting. The client can also send
a new auth message at any time to re-authenticate — for example, to
refresh an expiring JWT or to switch workspace. The
resolved identity (user, workspace, roles) is updated on each
successful auth.
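
The handshake described above can be sketched as a small per-connection state machine. This is a minimal illustration under stated assumptions: `SocketSession`, the `validate` callable, the `{"type": "error", ...}` rejection frame, and the placeholder `dispatch` reply are all hypothetical; only the `auth`, `auth-ok`, and `auth-failed` frames come from the spec.

```python
from typing import Callable, Optional

# Hypothetical types: `validate` stands in for the gateway's token
# validation and returns an identity dict, or None for a bad token.
Identity = dict
Validator = Callable[[str], Optional[Identity]]


class SocketSession:
    """Tracks the authentication state of one WebSocket connection."""

    def __init__(self, validate: Validator):
        self.validate = validate
        self.identity: Optional[Identity] = None  # None = unauthenticated

    def handle(self, message: dict) -> dict:
        """Return the reply frame for one incoming message."""
        if message.get("type") == "auth":
            identity = self.validate(message.get("token", ""))
            if identity is None:
                # The socket stays open; identity changes only on success.
                return {"type": "auth-failed", "error": "invalid token"}
            self.identity = identity
            return {"type": "auth-ok", "workspace": identity["workspace"]}
        if self.identity is None:
            # Non-auth traffic is rejected until authentication succeeds.
            # The shape of this error frame is an assumption.
            return {"type": "error", "error": "not authenticated"}
        return self.dispatch(message)

    def dispatch(self, message: dict) -> dict:
        # Placeholder for real request handling.
        return {"type": "ok", "workspace": self.identity["workspace"]}


# Toy token table standing in for real validation.
tokens = {"tg_abc123": {"user": "alice", "workspace": "acme", "roles": ["reader"]}}
session = SocketSession(tokens.get)
```

Keeping the state on the connection object means re-authentication is just another `auth` frame that overwrites the stored identity.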
#### API keys
For programmatic access: CLI tools, scripts, and integrations.
- Opaque tokens (e.g. `tg_a1b2c3d4e5f6...`). Not JWTs — short,
simple, easy to paste into CLI tools and headers.
- Each user has one or more API keys.
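- Keys are stored hashed (SHA-256 with salt) in the IAM service. Because
the hash is also the lookup key, the salt must be a system-wide value
rather than a per-key random one. The plaintext key is returned once at
creation time and cannot be retrieved afterwards.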
- Keys can be revoked individually without affecting other users.
- Keys optionally have an expiry date. Expired keys are rejected.
On each request, the gateway resolves an API key by:
1. Hashing the token.
2. Checking a local cache (hash → user/workspace/roles).
3. On cache miss, calling the IAM service to resolve.
4. Caching the result with a short TTL (e.g. 60 seconds).
Revoked keys stop working when the cache entry expires, so revocation
takes effect within one cache TTL at most. No push invalidation is
needed.
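
The resolution steps above can be sketched as a small TTL cache. This is an illustrative sketch only: `ApiKeyCache`, the injected `lookup` callable (a stand-in for the IAM service call), and the injectable `clock` are assumptions. The salt is omitted for brevity, and only successful lookups are cached, which is a choice made for this example rather than something the spec states.

```python
import hashlib
import time
from typing import Callable, Optional


class ApiKeyCache:
    """Resolves API keys via a lookup function, caching results by hash."""

    def __init__(self, lookup: Callable[[str], Optional[dict]],
                 ttl: float = 60.0,
                 clock: Callable[[], float] = time.monotonic):
        self.lookup = lookup  # hypothetical stand-in for the IAM service call
        self.ttl = ttl
        self.clock = clock
        # key hash -> (expiry time, identity)
        self._entries: dict[str, tuple[float, dict]] = {}

    def resolve(self, token: str) -> Optional[dict]:
        # 1. Hash the token (salt omitted in this sketch).
        key_hash = hashlib.sha256(token.encode("utf-8")).hexdigest()
        now = self.clock()
        # 2. Check the local cache.
        cached = self._entries.get(key_hash)
        if cached is not None and cached[0] > now:
            return cached[1]
        # 3. On a miss or an expired entry, ask the IAM service.
        identity = self.lookup(key_hash)
        # 4. Cache successful resolutions with a short TTL.
        if identity is not None:
            self._entries[key_hash] = (now + self.ttl, identity)
        return identity
```

With a 60-second TTL, a key revoked in the IAM service keeps resolving from the cache for at most 60 seconds, then fails on the next lookup.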
#### JWTs (login sessions)
For interactive access via the UI or WebSocket connections.
- A user logs in with username and password. The gateway forwards the
request to the IAM service, which validates the credentials and
returns a signed JWT.
- The JWT carries the user ID, workspace, and roles as claims.
- The gateway validates JWTs locally using the IAM service's public
signing key — no service call needed on subsequent requests.
- Token expiry is enforced by standard JWT validation at the time the
request is made or, for WebSocket connections, when the auth message
is received.
- For long-lived WebSocket connections, the JWT is validated only when
an auth message is processed, not on every frame. The connection
remains authenticated until the client re-authenticates or
disconnects.
The IAM service manages the signing key. The gateway fetches the public
key at startup (or on first JWT encounter) and caches it.
#### Login endpoint
```
POST /api/v1/auth/login
{
"username": "alice",
"password": "..."
}
→ {
"token": "eyJ...",
"expires": "2026-04-20T19:00:00Z"
}
```
The gateway forwards this to the IAM service, which validates
credentials and returns a signed JWT. The gateway returns the JWT to
the caller.
#### IAM service delegation
The gateway stays thin. Its authentication logic is:
1. Extract the Bearer token from the `Authorization` header (or, for
WebSocket, from the first-frame auth message).
2. If the token has JWT format (dotted structure), validate the
signature locally and extract claims.
3. Otherwise, treat as an API key: hash it and check the local cache.
On cache miss, call the IAM service to resolve.
4. If neither succeeds, return 401.
All user management, key management, credential validation, and token
signing logic lives in the IAM service. The gateway is a generic
enforcement point that can be replaced without changing the IAM
service.
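
The format-based dispatch in step 2 could look roughly like the sketch below. The function name `credential_kind` and its return labels are hypothetical. The check only classifies a token by shape (three dot-separated base64url segments, as in a compact JWT); a token classified as `jwt` must still pass signature validation.

```python
import re

# Three base64url segments separated by dots, as in a compact JWT.
_JWT_SHAPE = re.compile(r"^[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+$")


def credential_kind(authorization: str) -> str:
    """Classify an Authorization header value as 'jwt', 'api-key' or 'none'."""
    scheme, _, token = authorization.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return "none"  # the gateway would answer 401 here
    return "jwt" if _JWT_SHAPE.match(token) else "api-key"
```

Because API keys such as `tg_a1b2c3d4e5f6` contain no dots, the two credential types cannot be confused by this check.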
#### No legacy token support
The existing `GATEWAY_SECRET` shared token is removed. All
authentication uses API keys or JWTs. On first start, the bootstrap
process creates a default workspace and admin user with an initial API
key.
### User identity
A user belongs to exactly one workspace. The design supports extending
this to multi-workspace access in the future (see
[Extension points](#extension-points)).
A user record contains:
| Field | Type | Description |
|-------|------|-------------|
| `id` | string | Unique user identifier (UUID) |
| `name` | string | Display name |
| `email` | string | Email address (optional) |
| `workspace` | string | Workspace the user belongs to |
| `roles` | list[string] | Assigned roles (e.g. `["reader"]`) |
| `enabled` | bool | Whether the user can authenticate |
| `created` | datetime | Account creation timestamp |
The `workspace` field maps to the existing `user` field in `Metadata`.
This means the storage-layer isolation (Cassandra, Neo4j, Qdrant
filtering by `user` + `collection`) works without changes — the gateway
sets the `user` metadata field to the authenticated user's workspace.
### Workspaces
A workspace is an isolated data boundary. Users belong to a workspace,
and all data operations are scoped to it. Workspaces map to the existing
`user` field in `Metadata` and the corresponding Cassandra keyspace,
Qdrant collection prefix, and Neo4j property filters.
| Field | Type | Description |
|-------|------|-------------|
| `id` | string | Unique workspace identifier |
| `name` | string | Display name |
| `enabled` | bool | Whether the workspace is active |
| `created` | datetime | Creation timestamp |
The gateway determines the effective workspace for each request as
follows:
1. If the request includes a `workspace` parameter, validate it against
the user's assigned workspace.
- If it matches, use it.
- If it does not match, return 403. (This could be extended to
check a workspace access grant list.)
2. If no `workspace` parameter is provided, use the user's assigned
workspace.
The gateway sets the `user` field in `Metadata` to the effective
workspace ID, replacing the caller-supplied `?user=` query parameter.
This design ensures forward compatibility. Clients that pass a
workspace parameter will work unchanged if multi-workspace support is
added later. Requests for an unassigned workspace get a clear 403
rather than silent misbehaviour.
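
The resolution rule above is small enough to sketch directly. The names `effective_workspace` and `Forbidden` are assumptions for illustration; the exception stands in for whatever mechanism the gateway uses to produce an HTTP 403.

```python
from typing import Optional


class Forbidden(Exception):
    """Hypothetical exception that the gateway would map to HTTP 403."""


def effective_workspace(assigned: str, requested: Optional[str] = None) -> str:
    """Return the workspace to place in the `user` field of Metadata."""
    if requested is None:
        # No explicit parameter: fall back to the user's own workspace.
        return assigned
    if requested != assigned:
        # A future extension could consult a workspace grant list here.
        raise Forbidden(f"no access to workspace {requested!r}")
    return requested
```

For example, `effective_workspace("acme")` and `effective_workspace("acme", "acme")` both yield `"acme"`, while `effective_workspace("acme", "globex")` raises `Forbidden`.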
### Roles and access control
Three roles with fixed permissions:
| Role | Data operations | Admin operations | System |
|------|----------------|-----------------|--------|
| `reader` | Query knowledge graph, embeddings, RAG | None | None |
| `writer` | All reader operations + load documents, manage collections | None | None |
| `admin` | All writer operations | Config, flows, collection management, user management | Metrics |
Role checks happen at the gateway before dispatching to backend
services. Each endpoint declares the minimum role required:
| Endpoint pattern | Minimum role |
|-----------------|--------------|
| `GET /api/v1/socket` (queries) | `reader` |
| `POST /api/v1/librarian` | `writer` |
| `POST /api/v1/flow/*/import/*` | `writer` |
| `POST /api/v1/config` | `admin` |
| `GET /api/v1/flow/*` | `admin` |
| `GET /api/metrics` | `admin` |
Roles are hierarchical: `admin` implies `writer`, which implies
`reader`.
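
One way to express the hierarchy is a rank table, sketched below. `ROLE_RANK` and `has_role` are hypothetical names; the sketch assumes unknown role strings grant nothing.

```python
# Higher rank implies every lower rank: admin > writer > reader.
ROLE_RANK = {"reader": 0, "writer": 1, "admin": 2}


def has_role(user_roles: list[str], minimum: str) -> bool:
    """True if any of the user's roles meets the endpoint's minimum role."""
    required = ROLE_RANK[minimum]
    # Unknown roles rank below every defined role (an assumption).
    return any(ROLE_RANK.get(role, -1) >= required for role in user_roles)
```

A user holding only `writer` passes a `reader` or `writer` check and fails an `admin` check.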
### IAM service
The IAM service is a new backend service that manages all identity and
access data. It is the authority for users, workspaces, API keys, and
credentials. The gateway delegates to it.
#### Data model
```
iam_workspaces (
id text PRIMARY KEY,
name text,
enabled boolean,
created timestamp
)
iam_users (
id text PRIMARY KEY,
workspace text,
name text,
email text,
password_hash text,
roles set<text>,
enabled boolean,
created timestamp
)
iam_api_keys (
key_hash text PRIMARY KEY,
user_id text,
name text,
expires timestamp,
created timestamp
)
```
A secondary index on `iam_api_keys.user_id` supports listing a user's
keys.
#### Responsibilities
- User CRUD (create, list, update, disable)
- Workspace CRUD (create, list, update, disable)
- API key management (create, revoke, list)
- API key resolution (hash → user/workspace/roles)
- Credential validation (username/password → signed JWT)
- JWT signing key management (initialise, rotate)
- Bootstrap (create default workspace and admin user on first start)
#### Communication
The IAM service communicates via the standard request/response pub/sub
pattern, the same as the config service. The gateway calls it to
resolve API keys and to handle login requests. User management
operations (create user, revoke key, etc.) also go through the IAM
service.
### Gateway changes
The current `Authenticator` class is replaced with a thin authentication
middleware that delegates to the IAM service:
For HTTP requests:
1. Extract Bearer token from the `Authorization` header.
2. If the token has JWT format (dotted structure):
- Validate signature locally using the cached public key.
- Extract user ID, workspace, and roles from claims.
3. Otherwise, treat as an API key:
- Hash the token and check the local cache.
- On cache miss, call the IAM service to resolve.
- Cache the result (user/workspace/roles) with a short TTL.
4. If neither succeeds, return 401.
5. If the user or workspace is disabled, return 403.
6. Check the user's role against the endpoint's minimum role. If
insufficient, return 403.
7. Resolve the effective workspace:
- If the request includes a `workspace` parameter, validate it
against the user's assigned workspace. Return 403 on mismatch.
- If no `workspace` parameter, use the user's assigned workspace.
8. Set the `user` field in the request context to the effective
workspace ID. This propagates through `Metadata` to all downstream
services.
For WebSocket connections:
1. Accept the connection in an unauthenticated state.
2. Wait for an auth message (`{"type": "auth", "token": "..."}`).
3. Validate the token using the same logic as steps 2-7 above.
4. On success, attach the resolved identity to the connection and
send `{"type": "auth-ok", ...}`.
5. On failure, send `{"type": "auth-failed", ...}` but keep the
socket open.
6. Reject all non-auth messages until authentication succeeds.
7. Accept new auth messages at any time to re-authenticate.
### CLI changes
CLI tools authenticate with API keys:
- `--api-key` argument on all CLI tools, replacing `--api-token`.
- `tg-create-workspace`, `tg-list-workspaces` for workspace management.
- `tg-create-user`, `tg-list-users`, `tg-disable-user` for user
management.
- `tg-create-api-key`, `tg-list-api-keys`, `tg-revoke-api-key` for
key management.
- `--workspace` argument on tools that operate on workspace-scoped
data.
- The API key is passed as a Bearer token in the same way as the
current shared token, so the transport protocol is unchanged.
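As a rough illustration of the argument handling, a CLI tool might wire the two arguments like this. The program name `tg-example` and the helper names are hypothetical; only `--api-key`, `--workspace`, and the Bearer header come from the text above.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Arguments shared by CLI tools in this sketch."""
    parser = argparse.ArgumentParser(prog="tg-example")
    parser.add_argument("--api-key", required=True,
                        help="API key used to authenticate to the gateway")
    parser.add_argument("--workspace", default=None,
                        help="Workspace to operate on (optional)")
    return parser


def auth_headers(api_key: str) -> dict:
    """The API key travels as a Bearer token, like the old shared token."""
    return {"Authorization": f"Bearer {api_key}"}
```

Because the header format is unchanged, only the argument name and the credential behind it differ from the current shared-token behaviour.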
### Audit logging
With user identity established, the gateway logs the following for each
request: timestamp, user ID, workspace, endpoint, HTTP method, and
response status. Audit logs are written to the standard logging output
(structured JSON). Integration with external log aggregation (Loki, ELK)
is a deployment concern, not an application concern.
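A structured audit record could be produced with the standard library alone. This is a sketch under assumptions: the logger name `audit`, the function names, and the JSON field names are illustrative and not taken from the codebase.

```python
import json
import logging
from datetime import datetime, timezone

# Hypothetical logger name; output goes wherever logging is configured.
audit_logger = logging.getLogger("audit")


def audit_record(user_id: str, workspace: str, endpoint: str,
                 method: str, status: int) -> str:
    """Serialise one audit entry as a single JSON line."""
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "user_id": user_id,
        "workspace": workspace,
        "endpoint": endpoint,
        "method": method,
        "status": status,
    }
    return json.dumps(record)


def log_request(user_id: str, workspace: str, endpoint: str,
                method: str, status: int) -> None:
    """Emit the audit entry through the standard logging machinery."""
    audit_logger.info(audit_record(user_id, workspace, endpoint, method, status))
```

One JSON object per line keeps the output easy for an external aggregator to parse without the application knowing about it.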
### Config service changes
All configuration is workspace-scoped (see
[data-ownership-model.md](data-ownership-model.md)). The config service
needs to support this.
#### Schema change
The config table adds workspace as a key dimension:
```
config (
workspace text,
class text,
key text,
value text,
PRIMARY KEY ((workspace, class), key)
)
```
#### Request format
Config requests add a `workspace` field at the request level. The
existing `(type, key)` structure is unchanged within each workspace.
**Get:**
```json
{
"operation": "get",
"workspace": "workspace-a",
"keys": [{"type": "prompt", "key": "rag-prompt"}]
}
```
**Put:**
```json
{
"operation": "put",
"workspace": "workspace-a",
"values": [{"type": "prompt", "key": "rag-prompt", "value": "..."}]
}
```
**List (all keys of a type within a workspace):**
```json
{
"operation": "list",
"workspace": "workspace-a",
"type": "prompt"
}
```
**Delete:**
```json
{
"operation": "delete",
"workspace": "workspace-a",
"keys": [{"type": "prompt", "key": "rag-prompt"}]
}
```
The workspace is set by:
- **Gateway** — from the authenticated user's workspace for API-facing
requests.
- **Internal services** — explicitly, based on `Metadata.user` from
the message being processed, or `_system` for operational config.
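For an internal service, choosing the workspace and building a request might look like the following. The helper names and the `operational` flag are hypothetical; the request shape mirrors the JSON examples above.

```python
from typing import Optional

SYSTEM_WORKSPACE = "_system"  # reserved workspace for operational config


def workspace_for(metadata_user: Optional[str], operational: bool = False) -> str:
    """Pick the workspace for a config request made by an internal service."""
    if operational:
        return SYSTEM_WORKSPACE
    if metadata_user is None:
        raise ValueError("workspace-scoped request needs Metadata.user")
    return metadata_user


def build_get_request(workspace: str, config_type: str, key: str) -> dict:
    """Build a 'get' request in the format shown above."""
    return {
        "operation": "get",
        "workspace": workspace,
        "keys": [{"type": config_type, "key": key}],
    }
```

Raising when no workspace is available, rather than falling back to `_system`, is a design choice of this sketch: it avoids silently reading system config for a workspace-scoped request.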
#### System config namespace
Processor-level operational config (logging levels, connection strings,
resource limits) is not workspace-specific. This stays in a reserved
`_system` workspace that is not associated with any user workspace.
Services read system config at startup without needing a workspace
context.
#### Config change notifications
The config notify mechanism pushes change notifications via pub/sub
when config is updated. A single update may affect multiple workspaces
and multiple config types. The notification message carries a dict of
changes keyed by config type, with each value being the list of
affected workspaces:
```json
{
"version": 42,
"changes": {
"prompt": ["workspace-a", "workspace-b"],
"schema": ["workspace-a"]
}
}
```
System config changes use the reserved `_system` workspace:
```json
{
"version": 43,
"changes": {
"logging": ["_system"]
}
}
```
This structure is keyed by type because handlers register by type. A
handler registered for `prompt` looks up `"prompt"` directly and gets
the list of affected workspaces — no iteration over unrelated types.
#### Config change handlers
The current `on_config` hook mechanism needs two modes to support shared
processing services:
- **Workspace-scoped handlers** — notify when a config type changes in a
specific workspace. The handler looks up its registered type in the
changes dict and checks if its workspace is in the list. Used by the
gateway and by services that serve a single workspace.
- **Global handlers** — notify when a config type changes in any
workspace. The handler looks up its registered type in the changes
dict and gets the full list of affected workspaces. Used by shared
processing services (prompt-rag, agent manager, etc.) that serve all
workspaces. Each workspace in the list tells the handler which cache
entry to update rather than reloading everything.
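The two handler modes can be illustrated with a small dispatcher over the notification format shown earlier. The `ConfigHandlers` class and the callback signature `(workspaces, version)` are hypothetical and chosen for this sketch only; the real `on_config` hooks may differ.

```python
from typing import Callable, List, Optional, Tuple

Callback = Callable[[List[str], int], None]


class ConfigHandlers:
    """Routes config change notifications to registered handlers."""

    def __init__(self) -> None:
        # Each entry: (config type, workspace or None for global, callback)
        self._handlers: List[Tuple[str, Optional[str], Callback]] = []

    def register(self, config_type: str, callback: Callback,
                 workspace: Optional[str] = None) -> None:
        """workspace=None registers a global handler."""
        self._handlers.append((config_type, workspace, callback))

    def dispatch(self, notification: dict) -> None:
        changes = notification.get("changes", {})
        version = notification.get("version")
        for config_type, workspace, callback in self._handlers:
            # Direct lookup by type: unrelated types are never scanned.
            affected = changes.get(config_type)
            if not affected:
                continue
            if workspace is None:
                callback(list(affected), version)   # global mode
            elif workspace in affected:
                callback([workspace], version)      # workspace-scoped mode
```

A global handler receives the full list of affected workspaces so it can refresh only those cache entries, while a scoped handler fires only when its own workspace appears in the list.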
#### Per-workspace config caching
Shared services that handle messages from multiple workspaces maintain a
per-workspace config cache. When a message arrives, the service looks up
the config for the workspace identified in `Metadata.user`. If the
workspace is not yet cached, the service fetches its config on demand.
Config change notifications update the relevant cache entry.
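A per-workspace cache with on-demand fetch might be sketched as follows. The class name and the `fetch` callback are hypothetical, and refreshing only workspaces that are already cached is an assumption of this sketch rather than something the spec states.

```python
from typing import Callable, Dict, Iterable


class WorkspaceConfigCache:
    """Caches config per workspace, fetching on first use."""

    def __init__(self, fetch: Callable[[str], dict]):
        self._fetch = fetch              # fetch(workspace) -> config dict
        self._cache: Dict[str, dict] = {}

    def get(self, workspace: str) -> dict:
        """Return cached config, fetching it if this workspace is new."""
        if workspace not in self._cache:
            self._cache[workspace] = self._fetch(workspace)
        return self._cache[workspace]

    def on_change(self, workspaces: Iterable[str]) -> None:
        """Refresh entries for workspaces named in a change notification."""
        for workspace in workspaces:
            if workspace in self._cache:
                self._cache[workspace] = self._fetch(workspace)
```

A message handler would call `cache.get(...)` with the workspace taken from `Metadata.user`, and a global config handler would call `on_change` with the list of affected workspaces.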
### Flow and queue isolation
Flows are workspace-owned. When two workspaces start flows with the same
name and blueprint, their queues must be separate to prevent data
mixing.
Flow blueprint templates currently use `{id}` (flow instance ID) and
`{class}` (blueprint name) as template variables in queue names. A new
`{workspace}` variable is added so queue names include the workspace:
**Current queue names (no workspace isolation):**
```
flow:tg:document-load:{id} → flow:tg:document-load:default
request:tg:embeddings:{class} → request:tg:embeddings:everything
```
**With workspace isolation:**
```
flow:tg:{workspace}:document-load:{id} → flow:tg:ws-a:document-load:default
request:tg:{workspace}:embeddings:{class} → request:tg:ws-a:embeddings:everything
```
The flow service substitutes `{workspace}` from the authenticated
workspace when starting a flow, the same way it substitutes `{id}` and
`{class}` today.
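The substitution can be illustrated with plain string formatting. The function name `queue_name` is hypothetical, and the sketch assumes templates contain no braces other than the three variables.

```python
def queue_name(template: str, workspace: str, flow_id: str, blueprint: str) -> str:
    """Expand {workspace}, {id} and {class} in a queue name template."""
    # 'class' is a Python keyword, so the values are passed via a dict.
    return template.format(**{
        "workspace": workspace,
        "id": flow_id,
        "class": blueprint,
    })


print(queue_name("flow:tg:{workspace}:document-load:{id}",
                 "ws-a", "default", "everything"))
# → flow:tg:ws-a:document-load:default
print(queue_name("request:tg:{workspace}:embeddings:{class}",
                 "ws-a", "default", "everything"))
# → request:tg:ws-a:embeddings:everything
```

A template that lacks `{workspace}` passes through this function with no workspace in the result, which is exactly the data-mixing case the new variable is meant to prevent.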
Processing services are shared infrastructure: they consume from
workspace-specific queues and are not bound to any one workspace. The
workspace is carried in `Metadata.user` on every message, so a service
learns which workspace's data it is processing from the message itself.
Blueprint templates need updating to include `{workspace}` in all queue
name patterns. For migration, the flow service can inject the workspace
into queue names automatically if the template does not include
`{workspace}`, defaulting to the legacy behaviour for existing
blueprints.
See [flow-class-definition.md](flow-class-definition.md) for the full
blueprint template specification.
### What changes and what doesn't
**Changes:**
| Component | Change |
|-----------|--------|
| `gateway/auth.py` | Replace `Authenticator` with new auth middleware |
| `gateway/service.py` | Initialise IAM client, configure JWT validation |
| `gateway/endpoint/*.py` | Add role requirement per endpoint |
| Metadata propagation | Gateway sets `user` from workspace, ignores query param |
| Config service | Add workspace dimension to config schema |
| Config table | `PRIMARY KEY ((workspace, class), key)` |
| Config request/response schema | Add `workspace` field |
| Config notify messages | Include workspace ID in change notifications |
| `on_config` handlers | Support workspace-scoped and global modes |
| Shared services | Per-workspace config caching |
| Flow blueprints | Add `{workspace}` template variable to queue names |
| Flow service | Substitute `{workspace}` when starting flows |
| CLI tools | New user management commands, `--api-key` argument |
| Cassandra schema | New `iam_workspaces`, `iam_users`, `iam_api_keys` tables |
**Does not change:**
| Component | Reason |
|-----------|--------|
| Internal service-to-service pub/sub | Services trust the gateway |
| `Metadata` dataclass | `user` field continues to carry workspace identity |
| Storage-layer isolation | Same `user` + `collection` filtering |
| Message serialisation | No schema changes |
### Migration
This is a breaking change. Existing deployments must be reconfigured:
1. `GATEWAY_SECRET` is removed. Authentication requires API keys or
JWT login tokens.
2. The `?user=` query parameter is removed. Workspace identity comes
from authentication.
3. On first start, the IAM service bootstraps a default workspace and
admin user. The initial API key is output to the service log.
4. Operators create additional workspaces and users via CLI tools.
5. Flow blueprints must be updated to include `{workspace}` in queue
name patterns.
6. Config data must be migrated to include the workspace dimension.
## Extension points
The design includes deliberate extension points for future capabilities.
These are not implemented but the architecture does not preclude them:
- **Multi-workspace access.** Users could be granted access to
additional workspaces beyond their primary assignment. The workspace
validation step checks a grant list instead of a single assignment.
- **Rules-based access control.** A separate access control service
could evaluate fine-grained policies (per-collection permissions,
operation-level restrictions, time-based access). The gateway
delegates authorisation decisions to this service.
- **External identity provider integration.** SAML, LDAP, and OIDC
flows (group mapping, claims-based role assignment) could be added
to the IAM service.
- **Cross-workspace administration.** A `superadmin` role for platform
operators who manage multiple workspaces.
- **Delegated workspace provisioning.** APIs for programmatic workspace
creation and user onboarding.
These extensions are additive — they extend the validation logic
without changing the request/response protocol. The gateway can be
replaced with an alternative implementation that supports these
capabilities while the IAM service and backend services remain
unchanged.
## Implementation plan
Workspace support is a prerequisite for auth — users are assigned to
workspaces, config is workspace-scoped, and flows use workspace in
queue names. Implementing workspaces first allows the structural changes
to be tested end-to-end without auth complicating debugging.
### Phase 1: Workspace support (no auth)
This phase covers all workspace-scoped data and processing changes. The
system works with workspaces but without authentication: callers pass
the workspace as a parameter and are trusted to supply the correct one
(honour system). This allows full end-to-end testing: multiple
workspaces with separate flows, config, queues, and data.
#### Config service
- Update config client API to accept a workspace parameter on all
requests
- Update config storage schema to add workspace as a key dimension
- Update config notification API to report changes as a dict of
type → workspace list
- Update the processor base class to understand workspaces in config
notifications (workspace-scoped and global handler modes)
- Update all processors to implement workspace-aware config handling
(per-workspace config caching, on-demand fetch)
#### Flow and queue isolation
- Update flow blueprints to include `{workspace}` in all queue name
patterns
- Update the flow service to substitute `{workspace}` when starting
flows
- Update all built-in blueprints to include `{workspace}`
#### CLI tools (workspace support)
- Add `--workspace` argument to CLI tools that operate on
workspace-scoped data
- Add `tg-create-workspace`, `tg-list-workspaces` commands
### Phase 2: Authentication and access control
With workspaces working, add the IAM service and lock down the gateway.
#### IAM service
A new service handling identity and access management on behalf of the
API gateway:
- Add workspace table support (CRUD, enable/disable)
- Add user table support (CRUD, enable/disable, workspace assignment)
- Add roles support (role assignment, role validation)
- Add API key support (create, revoke, list, hash storage)
- Add ability to initialise a JWT signing key for token grants
- Add token grant endpoint: user/password login returns a signed JWT
- Add bootstrap/initialisation mechanism: ability to set the signing
key and create the initial workspace + admin user on first start
#### API gateway integration
- Add IAM middleware to the API gateway replacing the current
`Authenticator`
- Add local JWT validation (public key from IAM service)
- Add API key resolution with local cache (hash → user/workspace/roles,
cache miss calls IAM service, short TTL)
- Add login endpoint forwarding to IAM service
- Add workspace resolution: validate requested workspace against user
assignment
- Add role-based endpoint access checks
- Add user management API endpoints (forwarded to IAM service)
- Add audit logging (user ID, workspace, endpoint, method, status)
- Add WebSocket auth via first-message protocol (auth message after
connect, socket stays open on failure, re-auth supported)
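The API key resolution bullet above can be sketched as a small TTL cache. Several details here are assumptions: SHA-256 as the hash, a 60-second TTL, passing the hash (not the raw key) to the IAM service, and the `resolve` callback standing in for that service call.

```python
import hashlib
import time
from typing import Callable, Dict, Optional, Tuple


class ApiKeyCache:
    """Resolves API keys to identities, caching results for a short TTL."""

    def __init__(self, resolve: Callable[[str], Optional[dict]],
                 ttl_seconds: float = 60.0,
                 clock: Callable[[], float] = time.monotonic):
        self._resolve = resolve      # stand-in for the IAM service call
        self._ttl = ttl_seconds
        self._clock = clock
        self._entries: Dict[str, Tuple[float, dict]] = {}

    @staticmethod
    def hash_key(api_key: str) -> str:
        """Hash the key so the raw secret is never stored in the cache."""
        return hashlib.sha256(api_key.encode("utf-8")).hexdigest()

    def lookup(self, api_key: str) -> Optional[dict]:
        key_hash = self.hash_key(api_key)
        now = self._clock()
        entry = self._entries.get(key_hash)
        if entry is not None and entry[0] > now:
            return entry[1]                      # fresh cache hit
        identity = self._resolve(key_hash)       # cache miss or expired
        if identity is not None:
            self._entries[key_hash] = (now + self._ttl, identity)
        else:
            self._entries.pop(key_hash, None)    # do not cache failures
        return identity
```

The short TTL bounds how long a revoked key keeps working at the gateway; a shorter value tightens revocation at the cost of more IAM calls.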
#### CLI tools (auth support)
- Add `tg-create-user`, `tg-list-users`, `tg-disable-user` commands
- Add `tg-create-api-key`, `tg-list-api-keys`, `tg-revoke-api-key`
commands
- Replace `--api-token` with `--api-key` on existing CLI tools
#### Bootstrap and cutover
- Create default workspace and admin user on first start if IAM tables
are empty
- Remove `GATEWAY_SECRET` and `?user=` query parameter support
## Design Decisions
### IAM data store
IAM data is stored in dedicated Cassandra tables owned by the IAM
service, not in the config service. Reasons:
- **Security isolation.** The config service has a broad, generic
protocol. An access control failure on the config service could
expose credentials. A dedicated IAM service with a purpose-built
protocol limits the attack surface and makes security auditing
clearer.
- **Data model fit.** IAM needs indexed lookups (API key hash → user,
list keys by user). The config service's `(workspace, type, key) →
value` model stores opaque JSON strings with no secondary indexes.
- **Scope.** IAM data is global (workspaces, users, keys). Config is
workspace-scoped. Mixing global and workspace-scoped data in the
same store adds complexity.
- **Audit.** IAM operations (key creation, revocation, login attempts)
are security events that should be logged separately from general
config changes.
## Deferred to future design
- **External identity providers.** External identity provider support
  (SAML, LDAP, OIDC) is left for future implementation. The extension
  points section describes where this fits architecturally.
- **API key scoping.** API keys could be scoped to specific collections
within a workspace rather than granting workspace-wide access. To be
designed when the need arises.
- **Multi-workspace initialisation.** `tg-init-trustgraph` only
  initialises a single workspace.
## References
- [Data Ownership and Information Separation](data-ownership-model.md)
- [MCP Tool Bearer Token Specification](mcp-tool-bearer-token.md)
- [Multi-Tenant Support Specification](multi-tenant-support.md)
- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md)


@ -58,7 +58,7 @@ class TestAgentStructuredQueryIntegration:
async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config):
"""Test basic agent integration with structured query tool"""
# Arrange - Load tool configuration
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
# Create agent request
request = AgentRequest(
@ -119,6 +119,7 @@ Args: {
# Mock flow parameter in agent_processor.on_request
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -146,7 +147,7 @@ Args: {
async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config):
"""Test agent handling of structured query errors"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="Find data from a table that doesn't exist using structured query.",
@ -199,6 +200,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -221,7 +223,7 @@ Args: {
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
"""Test agent using structured query in multi-step reasoning"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="First find all customers from California, then tell me how many orders they have made.",
@ -279,6 +281,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -313,7 +316,7 @@ Args: {
}
}
await agent_processor.on_tools_config(tool_config_with_collection, "v1")
await agent_processor.on_tools_config("default", tool_config_with_collection, "v1")
request = AgentRequest(
question="Query the sales data for recent transactions.",
@ -371,6 +374,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -394,10 +398,10 @@ Args: {
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
"""Test that structured query tool arguments are properly validated"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
# Check that the tool was registered with correct arguments
tools = agent_processor.agent.tools
tools = agent_processor.agents["default"].tools
assert "structured-query" in tools
structured_tool = tools["structured-query"]
@ -414,7 +418,7 @@ Args: {
async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config):
"""Test that structured query results are properly formatted for agent consumption"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="Get customer information and format it nicely.",
@ -482,6 +486,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)


@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration:
)
# Set up schemas
proc.schemas = sample_schemas
proc.schemas = {"default": dict(sample_schemas)}
# Mock the client method
proc.client = MagicMock()
@ -94,6 +94,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration:
}
# Act - Update configuration
await integration_processor.on_schema_config(new_schema_config, "v2")
await integration_processor.on_schema_config("default", new_schema_config, "v2")
# Arrange - Test query using new schema
request = QuestionToStructuredQueryRequest(
@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration:
await integration_processor.on_message(msg, consumer, flow)
# Assert
assert "inventory" in integration_processor.schemas
assert "inventory" in integration_processor.schemas["default"]
response_call = flow_response.send.call_args
response = response_call[0][0]
assert response.detected_schemas == ["inventory"]
@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration:
graphql_generation_template="custom-graphql-generator"
)
custom_processor.schemas = sample_schemas
custom_processor.schemas = {"default": dict(sample_schemas)}
custom_processor.client = MagicMock()
request = QuestionToStructuredQueryRequest(
@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration:
] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)]
)
integration_processor.schemas.update(large_schema_set)
integration_processor.schemas["default"].update(large_schema_set)
request = QuestionToStructuredQueryRequest(
question="Show me data from table_05 and table_12",
@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration:
msg.properties.return_value = {"id": f"concurrent-test-{i}"}
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response


@ -185,6 +185,7 @@ class TestObjectExtractionServiceIntegration:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
return context
@pytest.mark.asyncio
@ -197,20 +198,21 @@ class TestObjectExtractionServiceIntegration:
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Act
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Assert
assert len(processor.schemas) == 2
assert "customer_records" in processor.schemas
assert "product_catalog" in processor.schemas
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 2
assert "customer_records" in ws_schemas
assert "product_catalog" in ws_schemas
# Verify customer schema
customer_schema = processor.schemas["customer_records"]
customer_schema = ws_schemas["customer_records"]
assert customer_schema.name == "customer_records"
assert len(customer_schema.fields) == 4
# Verify product schema
product_schema = processor.schemas["product_catalog"]
product_schema = ws_schemas["product_catalog"]
assert product_schema.name == "product_catalog"
assert len(product_schema.fields) == 4
@ -237,7 +239,7 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create realistic customer data chunk
metadata = Metadata(
@ -304,7 +306,7 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create realistic product data chunk
metadata = Metadata(
@ -368,7 +370,7 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create multiple test chunks
chunks_data = [
@ -431,19 +433,21 @@ class TestObjectExtractionServiceIntegration:
"customer_records": integration_config["schema"]["customer_records"]
}
}
await processor.on_schema_config(initial_config, version=1)
assert len(processor.schemas) == 1
assert "customer_records" in processor.schemas
assert "product_catalog" not in processor.schemas
await processor.on_schema_config("default", initial_config, version=1)
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 1
assert "customer_records" in ws_schemas
assert "product_catalog" not in ws_schemas
# Act - Reload with full configuration
await processor.on_schema_config(integration_config, version=2)
await processor.on_schema_config("default", integration_config, version=2)
# Assert
assert len(processor.schemas) == 2
assert "customer_records" in processor.schemas
assert "product_catalog" in processor.schemas
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 2
assert "customer_records" in ws_schemas
assert "product_catalog" in ws_schemas
@pytest.mark.asyncio
async def test_error_resilience_integration(self, integration_config):
@ -474,10 +478,11 @@ class TestObjectExtractionServiceIntegration:
return AsyncMock()
failing_flow.side_effect = failing_context_router
failing_flow.workspace = "default"
processor.flow = failing_flow
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create test chunk
metadata = Metadata(id="error-test", user="test", collection="test")
@ -510,7 +515,7 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create chunk with rich metadata
original_metadata = Metadata(


@ -87,6 +87,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
return context
@pytest.fixture
@ -109,7 +110,7 @@ class TestPromptStreaming:
def prompt_processor_streaming(self, mock_prompt_manager):
"""Create Prompt processor with streaming support"""
processor = MagicMock()
processor.manager = mock_prompt_manager
processor.managers = {"default": mock_prompt_manager}
processor.config_key = "prompt"
# Bind the actual on_request method
@ -248,6 +249,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
request = PromptRequest(
id="test_prompt",
@ -341,6 +343,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
request = PromptRequest(
id="test_prompt",


@ -14,6 +14,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class _MockFlowDefault:
"""Mock Flow with default workspace for testing."""
workspace = "default"
name = "default"
id = "test-processor"
mock_flow_default = _MockFlowDefault()
@pytest.mark.integration
class TestRowsCassandraIntegration:
"""Integration tests for Cassandra row storage with unified table"""
@ -125,8 +136,8 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
assert "customer_records" in processor.schemas
await processor.on_schema_config("default", config, version=1)
assert "customer_records" in processor.schemas["default"]
# Step 2: Process an ExtractedObject
test_obj = ExtractedObject(
@ -149,7 +160,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify Cassandra interactions
assert mock_cluster.connect.called
@ -209,8 +220,8 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
assert len(processor.schemas) == 2
await processor.on_schema_config("default", config, version=1)
assert len(processor.schemas["default"]) == 2
# Process objects for different schemas
product_obj = ExtractedObject(
@ -233,7 +244,7 @@ class TestRowsCassandraIntegration:
for obj in [product_obj, order_obj]:
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# All data goes into the same unified rows table
table_calls = [call for call in mock_session.execute.call_args_list
@ -256,15 +267,17 @@ class TestRowsCassandraIntegration:
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Schema with multiple indexed fields
processor.schemas["indexed_data"] = RowSchema(
name="indexed_data",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True),
Field(name="status", type="string", size=50, indexed=True),
Field(name="description", type="string", size=200) # Not indexed
]
)
processor.schemas["default"] = {
"indexed_data": RowSchema(
name="indexed_data",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True),
Field(name="status", type="string", size=50, indexed=True),
Field(name="description", type="string", size=200) # Not indexed
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test"),
@ -282,7 +295,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should have 3 data inserts (one per indexed field: id, category, status)
rows_insert_calls = [call for call in mock_session.execute.call_args_list
@ -342,7 +355,7 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
await processor.on_schema_config("default", config, version=1)
# Process batch object with multiple values
batch_obj = ExtractedObject(
@ -376,7 +389,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify unified table creation
table_calls = [call for call in mock_session.execute.call_args_list
@ -396,10 +409,12 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["empty_test"] = RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
processor.schemas["default"] = {
"empty_test": RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
}
# Process empty batch object
empty_obj = ExtractedObject(
@ -413,7 +428,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = empty_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should not create any data insert statements for empty batch
# (partition registration may still happen)
@ -428,14 +443,16 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["map_test"] = RowSchema(
name="map_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="count", type="integer", size=0)
]
)
processor.schemas["default"] = {
"map_test": RowSchema(
name="map_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="count", type="integer", size=0)
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test"),
@ -448,7 +465,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify insert uses map for data
rows_insert_calls = [call for call in mock_session.execute.call_args_list
@ -473,13 +490,15 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["partition_test"] = RowSchema(
name="partition_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True)
]
)
processor.schemas["default"] = {
"partition_test": RowSchema(
name="partition_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True)
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="my_collection"),
@ -492,7 +511,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify partition registration
partition_inserts = [call for call in mock_session.execute.call_args_list
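
Reviewer note: the hunks above move `processor.schemas` from a flat `name -> RowSchema` mapping to a nested `workspace -> name -> RowSchema` mapping. The following is a minimal sketch of that shape, not the project's implementation. The `Field` and `RowSchema` classes here are simplified stand-ins, and `lookup_schema` is a hypothetical helper introduced only for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class Field:
    # Simplified stand-in for the real Field type (assumption)
    name: str
    type: str
    primary: bool = False


@dataclass
class RowSchema:
    # Simplified stand-in for the real RowSchema type (assumption)
    name: str
    fields: list = field(default_factory=list)


def lookup_schema(schemas, workspace, name):
    """Hypothetical helper: return the schema for (workspace, name) or None."""
    return schemas.get(workspace, {}).get(name)


# Nested layout used by the updated tests: workspace first, then schema name
schemas = {
    "default": {
        "map_test": RowSchema(
            name="map_test",
            fields=[Field(name="id", type="string", primary=True)],
        )
    }
}

found = lookup_schema(schemas, "default", "map_test")
missing = lookup_schema(schemas, "other-workspace", "map_test")
```

With this layout, two workspaces can register schemas with the same name without overwriting each other, which appears to be the intent of the change.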


@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_schema_configuration_and_generation(self, processor, sample_schema_config):
"""Test schema configuration loading and GraphQL schema generation"""
# Load schema configuration
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Verify schemas were loaded
assert len(processor.schemas) == 2
@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config):
"""Test Cassandra connection and dynamic table creation"""
# Load schema configuration
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Connect to Cassandra
processor.connect_cassandra()
@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config):
"""Test inserting data and querying via GraphQL"""
# Load schema and connect
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Setup test data
@ -292,7 +292,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_graphql_query_with_filters(self, processor, sample_schema_config):
"""Test GraphQL queries with filtering on indexed fields"""
# Setup (reuse previous setup)
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
keyspace = "test_user"
@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_graphql_error_handling(self, processor, sample_schema_config):
"""Test GraphQL error handling for invalid queries"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Test invalid field query
invalid_query = '''
@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_message_processing_integration(self, processor, sample_schema_config):
"""Test full message processing workflow"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Create mock message
@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_concurrent_queries(self, processor, sample_schema_config):
"""Test handling multiple concurrent GraphQL queries"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Create multiple query tasks
@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration:
}
}
await processor.on_schema_config(initial_config, version=1)
await processor.on_schema_config("default", initial_config, version=1)
assert len(processor.schemas) == 1
assert "simple" in processor.schemas
@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration:
}
}
await processor.on_schema_config(updated_config, version=2)
await processor.on_schema_config("default", updated_config, version=2)
# Verify updated schemas
assert len(processor.schemas) == 2
@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_large_result_set_handling(self, processor, sample_schema_config):
"""Test handling of large query result sets"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
keyspace = "large_test_user"
@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance:
}
}
await processor.on_schema_config(schema_config, version=1)
await processor.on_schema_config("default", schema_config, version=1)
# Measure query execution time
start_time = time.time()


@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to call think and observe callbacks
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@ -58,6 +61,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock()
@ -129,6 +133,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to return Final directly
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@ -148,6 +155,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock()


@ -241,7 +241,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
async def mock_invoke(workspace, name, params):
return "tool result"
svc.invoke_tool = mock_invoke
@ -260,6 +260,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
@ -280,7 +281,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
async def mock_invoke(workspace, name, params):
return {"data": [1, 2, 3]}
svc.invoke_tool = mock_invoke
@ -298,6 +299,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@ -317,7 +319,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def failing_invoke(name, params):
async def failing_invoke(workspace, name, params):
raise RuntimeError("tool broke")
svc.invoke_tool = failing_invoke
@ -330,6 +332,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@ -350,7 +353,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def rate_limited(name, params):
async def rate_limited(workspace, name, params):
raise TooManyRequests("slow down")
svc.invoke_tool = rate_limited
@ -362,6 +365,7 @@ class TestToolServiceOnRequest:
flow = MagicMock()
flow.producer = {"response": AsyncMock()}
flow.name = "test-flow"
flow.workspace = "default"
with pytest.raises(TooManyRequests):
await svc.on_request(msg, MagicMock(), flow)
@ -376,7 +380,8 @@ class TestToolServiceOnRequest:
received = {}
async def capture_invoke(name, params):
async def capture_invoke(workspace, name, params):
received["workspace"] = workspace
received["name"] = name
received["params"] = params
return "ok"
@ -390,6 +395,7 @@ class TestToolServiceOnRequest:
flow = lambda name: mock_pub
flow.producer = {"response": mock_pub}
flow.name = "f"
flow.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(


@ -1,17 +1,14 @@
"""
Tests for AsyncProcessor config notify pattern:
- register_config_handler with types filtering
- on_config_notify version comparison and type matching
- fetch_config with short-lived client
- fetch_and_apply_config retry logic
- on_config_notify version comparison, type/workspace matching
- fetch_and_apply_config retry logic over per-workspace fetches
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, Mock
from trustgraph.schema import Term, IRI, LITERAL
# Patch heavy dependencies before importing AsyncProcessor
@pytest.fixture
def processor():
"""Create an AsyncProcessor with mocked dependencies."""
@ -68,6 +65,13 @@ class TestRegisterConfigHandler:
assert len(processor.config_handlers) == 2
def _notify_msg(version, changes):
"""Build a Mock config-notify message with given version and changes dict."""
msg = Mock()
msg.value.return_value = Mock(version=version, changes=changes)
return msg
class TestOnConfigNotify:
@pytest.mark.asyncio
@ -77,9 +81,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=3, types=["prompt"])
msg = _notify_msg(3, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -91,9 +93,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=5, types=["prompt"])
msg = _notify_msg(5, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -105,9 +105,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=2, types=["schema"])
msg = _notify_msg(2, {"schema": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -121,40 +119,36 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
# Mock fetch_config
mock_config = {"prompt": {"key": "value"}}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={"key": "value"},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
handler.assert_called_once_with(
"default", {"prompt": {"key": "value"}}, 2
)
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_handler_without_types_always_called(self, processor):
async def test_handler_without_types_ignored_on_notify(self, processor):
"""Handlers registered without types never fire on notifications."""
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler) # No types = all
processor.register_config_handler(handler) # No types
mock_config = {"anything": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 2)
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["whatever"])
msg = _notify_msg(2, {"whatever": ["default"]})
await processor.on_config_notify(msg, None, None)
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
handler.assert_not_called()
# Version still advances past the notify
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_mixed_handlers_type_filtering(self, processor):
@ -168,156 +162,149 @@ class TestOnConfigNotify:
processor.register_config_handler(schema_handler, types=["schema"])
processor.register_config_handler(all_handler)
mock_config = {"prompt": {}}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
prompt_handler.assert_called_once()
prompt_handler.assert_called_once_with(
"default", {"prompt": {}}, 2
)
schema_handler.assert_not_called()
all_handler.assert_called_once()
all_handler.assert_not_called()
@pytest.mark.asyncio
async def test_empty_types_invokes_all(self, processor):
"""Empty types list (startup signal) should invoke all handlers."""
async def test_multi_workspace_notify_invokes_handler_per_ws(
self, processor
):
"""Notify affecting multiple workspaces invokes handler once per workspace."""
processor.config_version = 1
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
mock_config = {}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=[])
msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]})
await processor.on_config_notify(msg, None, None)
h1.assert_called_once()
h2.assert_called_once()
assert handler.call_count == 2
called_workspaces = {c.args[0] for c in handler.call_args_list}
assert called_workspaces == {"ws1", "ws2"}
@pytest.mark.asyncio
async def test_fetch_failure_handled(self, processor):
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler)
processor.register_config_handler(handler, types=["prompt"])
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
side_effect=RuntimeError("Connection failed")
side_effect=RuntimeError("Connection failed"),
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
# Should not raise
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
class TestFetchConfig:
@pytest.mark.asyncio
async def test_fetch_returns_config_and_version(self, processor):
mock_resp = Mock()
mock_resp.error = None
mock_resp.config = {"prompt": {"key": "val"}}
mock_resp.version = 42
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
config, version = await processor.fetch_config()
assert config == {"prompt": {"key": "val"}}
assert version == 42
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_raises_on_error_response(self, processor):
mock_resp = Mock()
mock_resp.error = Mock(message="not found")
mock_resp.config = {}
mock_resp.version = 0
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(RuntimeError, match="Config error"):
await processor.fetch_config()
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_stops_client_on_exception(self, processor):
mock_client = AsyncMock()
mock_client.request.side_effect = TimeoutError("timeout")
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(TimeoutError):
await processor.fetch_config()
mock_client.stop.assert_called_once()
class TestFetchAndApplyConfig:
@pytest.mark.asyncio
async def test_applies_config_to_all_handlers(self, processor):
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
async def test_applies_config_per_workspace(self, processor):
"""Startup fetch invokes handler once per workspace affected."""
h = AsyncMock()
processor.register_config_handler(h, types=["prompt"])
mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {
"ws1": {"k": "v1"},
"ws2": {"k": "v2"},
}, 10
mock_config = {"prompt": {}, "schema": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 10)
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
):
await processor.fetch_and_apply_config()
# On startup, all handlers are invoked regardless of type
h1.assert_called_once_with(mock_config, 10)
h2.assert_called_once_with(mock_config, 10)
assert h.call_count == 2
call_map = {c.args[0]: c.args[1] for c in h.call_args_list}
assert call_map["ws1"] == {"prompt": {"k": "v1"}}
assert call_map["ws2"] == {"prompt": {"k": "v2"}}
assert processor.config_version == 10
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
call_count = 0
mock_config = {"prompt": {}}
async def test_handler_without_types_skipped_at_startup(self, processor):
"""Handlers registered without types fetch nothing at startup."""
typed = AsyncMock()
untyped = AsyncMock()
processor.register_config_handler(typed, types=["prompt"])
processor.register_config_handler(untyped)
async def mock_fetch():
mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {"default": {}}, 1
with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
):
await processor.fetch_and_apply_config()
typed.assert_called_once()
untyped.assert_not_called()
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
h = AsyncMock()
processor.register_config_handler(h, types=["prompt"])
call_count = 0
async def fake_fetch_all(client, config_type):
nonlocal call_count
call_count += 1
if call_count < 3:
raise RuntimeError("not ready")
return mock_config, 5
return {"default": {"k": "v"}}, 5
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
patch('asyncio.sleep', new_callable=AsyncMock):
mock_client = AsyncMock()
with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
), patch('asyncio.sleep', new_callable=AsyncMock):
await processor.fetch_and_apply_config()
assert call_count == 3
assert processor.config_version == 5
h.assert_called_once_with(
"default", {"prompt": {"k": "v"}}, 5
)
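
Reviewer note: the updated tests describe a notify message carrying a `changes` dict of `config type -> [workspaces]`, with typed handlers invoked once per affected workspace as `handler(workspace, {type: config}, version)`. The sketch below illustrates that dispatch rule under stated assumptions. `NotifyDispatcher` and its `fetch` callable are hypothetical names, not the real `AsyncProcessor` API, and the sketch calls the handler once per (type, workspace) pair, which may differ from how the real code batches types. Whether the version should advance after a failed fetch is not shown by the tests; this sketch advances it.

```python
import asyncio


class NotifyDispatcher:
    """Hypothetical sketch of per-workspace config dispatch."""

    def __init__(self, fetch):
        self.config_version = 0
        self.handlers = []   # list of (handler, set of config types)
        self.fetch = fetch   # async (config_type, workspace) -> dict

    def register(self, handler, types=None):
        # Handlers registered without types match nothing on notify
        self.handlers.append((handler, set(types or [])))

    async def on_notify(self, version, changes):
        # Ignore notifications that are not newer than the current version
        if version <= self.config_version:
            return
        for handler, types in self.handlers:
            for config_type in sorted(types & set(changes)):
                for workspace in changes[config_type]:
                    try:
                        config = await self.fetch(config_type, workspace)
                    except Exception:
                        # A failed fetch is swallowed; the handler is skipped
                        continue
                    await handler(workspace, {config_type: config}, version)
        self.config_version = version


calls = []


async def fake_fetch(config_type, workspace):
    return {"k": workspace}


async def handler(workspace, config, version):
    calls.append((workspace, config, version))


async def untyped_handler(workspace, config, version):
    calls.append(("untyped", workspace))


async def main():
    dispatcher = NotifyDispatcher(fake_fetch)
    dispatcher.config_version = 1
    dispatcher.register(handler, types=["prompt"])
    dispatcher.register(untyped_handler)
    await dispatcher.on_notify(2, {"prompt": ["ws1", "ws2"]})
    await dispatcher.on_notify(2, {"prompt": ["ws1"]})  # not newer: ignored
    return dispatcher


dispatcher = asyncio.run(main())
```

Keying the notification by workspace lets a processor refetch only the workspaces that changed, rather than reloading the whole configuration on every notify.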


@ -40,10 +40,11 @@ def test_flow_initialization_calls_registered_specs():
spec_two = MagicMock()
processor = MagicMock(specifications=[spec_one, spec_two])
flow = Flow("processor-1", "flow-a", processor, {"answer": 42})
flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42})
assert flow.id == "processor-1"
assert flow.name == "flow-a"
assert flow.workspace == "default"
assert flow.producer == {}
assert flow.consumer == {}
assert flow.parameter == {}
@ -54,7 +55,7 @@ def test_flow_initialization_calls_registered_specs():
def test_flow_start_and_stop_visit_all_consumers():
consumer_one = AsyncMock()
consumer_two = AsyncMock()
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
flow.consumer = {"one": consumer_one, "two": consumer_two}
asyncio.run(flow.start())
@ -67,7 +68,7 @@ def test_flow_start_and_stop_visit_all_consumers():
def test_flow_call_returns_values_in_priority_order():
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
flow.producer["shared"] = "producer-value"
flow.consumer["consumer-only"] = "consumer-value"
flow.consumer["shared"] = "consumer-value"


@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
flow_defn = {'config': 'test-config'}
# Act
await processor.start_flow(flow_name, flow_defn)
await processor.start_flow("default", flow_name, flow_defn)
# Assert - Flow should be created with access to processor specifications
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn)
# The flow should have access to the processor's specifications
# (The exact mechanism depends on Flow implementation)


@ -78,11 +78,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
flow_name = 'test-flow'
flow_defn = {'config': 'test-config'}
await processor.start_flow(flow_name, flow_defn)
await processor.start_flow("default", flow_name, flow_defn)
assert flow_name in processor.flows
assert ("default", flow_name) in processor.flows
mock_flow_class.assert_called_once_with(
'test-processor', flow_name, processor, flow_defn
'test-processor', flow_name, "default", processor, flow_defn
)
mock_flow.start.assert_called_once()
@ -103,11 +103,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
mock_flow_class.return_value = mock_flow
flow_name = 'test-flow'
await processor.start_flow(flow_name, {'config': 'test-config'})
await processor.start_flow("default", flow_name, {'config': 'test-config'})
await processor.stop_flow(flow_name)
await processor.stop_flow("default", flow_name)
assert flow_name not in processor.flows
assert ("default", flow_name) not in processor.flows
mock_flow.stop.assert_called_once()
@with_async_processor_patches
@ -120,7 +120,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
processor = FlowProcessor(**config)
await processor.stop_flow('non-existent-flow')
await processor.stop_flow("default", 'non-existent-flow')
assert processor.flows == {}
@ -146,11 +146,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data, version=1)
await processor.on_configure_flows("default", config_data, version=1)
assert 'test-flow' in processor.flows
assert ("default", 'test-flow') in processor.flows
mock_flow_class.assert_called_once_with(
'test-processor', 'test-flow', processor,
'test-processor', 'test-flow', "default", processor,
{'config': 'test-config'}
)
mock_flow.start.assert_called_once()
@ -171,7 +171,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data, version=1)
await processor.on_configure_flows("default", config_data, version=1)
assert processor.flows == {}
@ -189,7 +189,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
'other-data': 'some-value'
}
await processor.on_configure_flows(config_data, version=1)
await processor.on_configure_flows("default", config_data, version=1)
assert processor.flows == {}
@ -216,7 +216,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data1, version=1)
await processor.on_configure_flows("default", config_data1, version=1)
config_data2 = {
'processor:test-processor': {
@ -224,12 +224,12 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data2, version=2)
await processor.on_configure_flows("default", config_data2, version=2)
assert 'flow1' not in processor.flows
assert ("default", 'flow1') not in processor.flows
mock_flow1.stop.assert_called_once()
assert 'flow2' in processor.flows
assert ("default", 'flow2') in processor.flows
mock_flow2.start.assert_called_once()
@with_async_processor_patches
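
Reviewer note: the assertions in this file change from `flow_name in processor.flows` to `("default", flow_name) in processor.flows`, which suggests flows are now keyed by a `(workspace, name)` tuple. A minimal sketch of that keying follows. `FlowRegistry` is a hypothetical class for illustration only and stores the definition directly rather than a real `Flow` object.

```python
import asyncio


class FlowRegistry:
    """Hypothetical registry keyed by (workspace, flow name)."""

    def __init__(self):
        self.flows = {}

    async def start_flow(self, workspace, name, definition):
        # The same flow name may exist independently in several workspaces
        self.flows[(workspace, name)] = definition

    async def stop_flow(self, workspace, name):
        # Stopping an unknown flow is a no-op, as in the non-existent-flow test
        self.flows.pop((workspace, name), None)


async def demo():
    registry = FlowRegistry()
    await registry.start_flow("default", "test-flow", {"config": "a"})
    await registry.start_flow("other", "test-flow", {"config": "b"})
    await registry.stop_flow("default", "test-flow")
    await registry.stop_flow("default", "non-existent-flow")
    return registry


registry = asyncio.run(demo())
```

Stopping `test-flow` in one workspace leaves the identically named flow in the other workspace untouched, which is the isolation property the tuple key provides.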


@ -109,7 +109,8 @@ class TestListConfigItems:
url='http://custom.com',
config_type='prompt',
format_type='json',
token=None
token=None,
workspace='default'
)
def test_list_main_uses_defaults(self):
@ -128,7 +129,8 @@ class TestListConfigItems:
url='http://localhost:8088/',
config_type='prompt',
format_type='text',
token=None
token=None,
workspace='default'
)
@ -196,7 +198,8 @@ class TestGetConfigItem:
config_type='prompt',
key='template-1',
format_type='json',
token=None
token=None,
workspace='default'
)
@ -253,7 +256,8 @@ class TestPutConfigItem:
config_type='prompt',
key='new-template',
value='Custom prompt: {input}',
token=None
token=None,
workspace='default'
)
def test_put_main_with_stdin_arg(self):
@ -278,7 +282,8 @@ class TestPutConfigItem:
config_type='prompt',
key='stdin-template',
value=stdin_content,
token=None
token=None,
workspace='default'
)
def test_put_main_mutually_exclusive_args(self):
@ -334,7 +339,8 @@ class TestDeleteConfigItem:
url='http://custom.com',
config_type='prompt',
key='old-template',
token=None
token=None,
workspace='default'
)


@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob .
# Verify Api was created with correct parameters
mock_api_class.assert_called_once_with(
url="http://test.example.com/",
token="test-token"
token="test-token",
workspace="default"
)
# Verify bulk client was obtained


@ -145,7 +145,8 @@ class TestSetToolStructuredQuery:
group=None,
state=None,
applicable_states=None,
token=None
token=None,
workspace='default'
)
def test_set_main_structured_query_no_arguments_needed(self):
@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery:
group=None,
state=None,
applicable_states=None,
token=None
token=None,
workspace='default'
)
def test_valid_types_includes_row_embeddings_query(self):
@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery:
show_main()
mock_show.assert_called_once_with(url='http://custom.com', token=None)
mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default')
class TestShowToolsRowEmbeddingsQuery:


@ -28,10 +28,12 @@ def mock_flow_config():
"""Mock flow configuration."""
mock_config = Mock()
mock_config.flows = {
"test-flow": {
"interfaces": {
"triples-store": {"flow": "test-triples-queue"},
"graph-embeddings-store": {"flow": "test-ge-queue"}
"test-user": {
"test-flow": {
"interfaces": {
"triples-store": {"flow": "test-triples-queue"},
"graph-embeddings-store": {"flow": "test-ge-queue"}
}
}
}
}


@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
}
}
await processor.on_schema_config(config_data, 1)
await processor.on_schema_config("default", config_data, 1)
assert 'customers' in processor.schemas
assert processor.schemas['customers'].name == 'customers'
assert len(processor.schemas['customers'].fields) == 3
assert 'customers' in processor.schemas["default"]
assert processor.schemas["default"]['customers'].name == 'customers'
assert len(processor.schemas["default"]['customers'].fields) == 3
async def test_on_schema_config_handles_missing_type(self):
"""Test that missing schema type is handled gracefully"""
@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
'other_type': {}
}
await processor.on_schema_config(config_data, 1)
await processor.on_schema_config("default", config_data, 1)
assert processor.schemas == {}
assert processor.schemas.get("default", {}) == {}
async def test_on_message_drops_unknown_collection(self):
"""Test that messages for unknown collections are dropped"""
@ -325,14 +325,16 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
processor.known_collections[('test_user', 'test_collection')] = {}
# Set up schema
processor.schemas['customers'] = RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
processor.schemas["default"] = {
'customers': RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
}
metadata = MagicMock()
metadata.user = 'test_user'
@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
return MagicMock()
mock_flow = MagicMock(side_effect=flow_factory)
mock_flow.workspace = "default"
await processor.on_message(mock_msg, MagicMock(), mock_flow)


@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader
ConfigReceiver.config_loader = Mock()
def _notify(version, changes):
msg = Mock()
msg.value.return_value = Mock(version=version, changes=changes)
return msg
class TestConfigReceiver:
"""Test cases for ConfigReceiver class"""
@ -47,98 +53,70 @@ class TestConfigReceiver:
assert handler2 in config_receiver.flow_handlers
@pytest.mark.asyncio
async def test_on_config_notify_new_version(self):
"""Test on_config_notify triggers fetch for newer version"""
async def test_on_config_notify_new_version_fetches_per_workspace(self):
"""Notify with newer version fetches each affected workspace."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
# Mock fetch_and_apply
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with newer version
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=["flow"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
assert len(fetch_calls) == 1
msg = _notify(2, {"flow": ["ws1", "ws2"]})
await config_receiver.on_config_notify(msg, None, None)
assert set(fetch_calls) == {"ws1", "ws2"}
assert config_receiver.config_version == 2
@pytest.mark.asyncio
async def test_on_config_notify_old_version_ignored(self):
"""Test on_config_notify ignores older versions"""
"""Older-version notifies are ignored."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 5
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with older version
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=3, types=["flow"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
assert len(fetch_calls) == 0
msg = _notify(3, {"flow": ["ws1"]})
await config_receiver.on_config_notify(msg, None, None)
assert fetch_calls == []
@pytest.mark.asyncio
async def test_on_config_notify_irrelevant_types_ignored(self):
"""Test on_config_notify ignores types the gateway doesn't care about"""
"""Notifies without flow changes advance version but skip fetch."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with non-flow type
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
# Version should be updated but no fetch
assert len(fetch_calls) == 0
msg = _notify(2, {"prompt": ["ws1"]})
await config_receiver.on_config_notify(msg, None, None)
assert fetch_calls == []
assert config_receiver.config_version == 2
@pytest.mark.asyncio
async def test_on_config_notify_flow_type_triggers_fetch(self):
"""Test on_config_notify fetches for flow-related types"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
for type_name in ["flow"]:
fetch_calls.clear()
config_receiver.config_version = 1
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=[type_name])
await config_receiver.on_config_notify(mock_msg, None, None)
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
@pytest.mark.asyncio
async def test_on_config_notify_exception_handling(self):
"""Test on_config_notify handles exceptions gracefully"""
"""on_config_notify swallows exceptions from message decode."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Create notify message that causes an exception
mock_msg = Mock()
mock_msg.value.side_effect = Exception("Test exception")
@ -146,19 +124,18 @@ class TestConfigReceiver:
await config_receiver.on_config_notify(mock_msg, None, None)
@pytest.mark.asyncio
async def test_fetch_and_apply_with_new_flows(self):
"""Test fetch_and_apply starts new flows"""
async def test_fetch_and_apply_workspace_starts_new_flows(self):
"""fetch_and_apply_workspace starts newly-configured flows."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Mock _create_config_client to return a mock client
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}',
"flow2": '{"name": "test_flow_2"}'
"flow2": '{"name": "test_flow_2"}',
}
}
@ -167,36 +144,39 @@ class TestConfigReceiver:
config_receiver._create_config_client = Mock(return_value=mock_client)
start_flow_calls = []
async def mock_start_flow(id, flow):
start_flow_calls.append((id, flow))
async def mock_start_flow(workspace, id, flow):
start_flow_calls.append((workspace, id, flow))
config_receiver.start_flow = mock_start_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert config_receiver.config_version == 5
assert "flow1" in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow1" in config_receiver.flows["default"]
assert "flow2" in config_receiver.flows["default"]
assert len(start_flow_calls) == 2
assert all(c[0] == "default" for c in start_flow_calls)
@pytest.mark.asyncio
async def test_fetch_and_apply_with_removed_flows(self):
"""Test fetch_and_apply stops removed flows"""
async def test_fetch_and_apply_workspace_stops_removed_flows(self):
"""fetch_and_apply_workspace stops flows no longer configured."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate with existing flows
config_receiver.flows = {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
"default": {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"},
}
}
# Config now only has flow1
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}'
"flow1": '{"name": "test_flow_1"}',
}
}
@ -205,20 +185,22 @@ class TestConfigReceiver:
config_receiver._create_config_client = Mock(return_value=mock_client)
stop_flow_calls = []
async def mock_stop_flow(id, flow):
stop_flow_calls.append((id, flow))
async def mock_stop_flow(workspace, id, flow):
stop_flow_calls.append((workspace, id, flow))
config_receiver.stop_flow = mock_stop_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert "flow1" in config_receiver.flows
assert "flow2" not in config_receiver.flows
assert "flow1" in config_receiver.flows["default"]
assert "flow2" not in config_receiver.flows["default"]
assert len(stop_flow_calls) == 1
assert stop_flow_calls[0][0] == "flow2"
assert stop_flow_calls[0][:2] == ("default", "flow2")
@pytest.mark.asyncio
async def test_fetch_and_apply_with_no_flows(self):
"""Test fetch_and_apply with empty config"""
async def test_fetch_and_apply_workspace_with_no_flows(self):
"""Empty workspace config clears any local flow state."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
@ -231,88 +213,100 @@ class TestConfigReceiver:
mock_client.request.return_value = mock_resp
config_receiver._create_config_client = Mock(return_value=mock_client)
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert config_receiver.flows == {}
assert config_receiver.flows.get("default", {}) == {}
assert config_receiver.config_version == 1
@pytest.mark.asyncio
async def test_start_flow_with_handlers(self):
"""Test start_flow method with multiple handlers"""
"""start_flow fans out to every registered flow handler."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler1.start_flow = Mock()
handler1.start_flow = AsyncMock()
handler2 = Mock()
handler2.start_flow = Mock()
handler2.start_flow = AsyncMock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.start_flow("flow1", flow_data)
await config_receiver.start_flow("default", "flow1", flow_data)
handler1.start_flow.assert_called_once_with("flow1", flow_data)
handler2.start_flow.assert_called_once_with("flow1", flow_data)
handler1.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
handler2.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_start_flow_with_handler_exception(self):
"""Test start_flow method handles handler exceptions"""
"""Handler exceptions in start_flow do not propagate."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler = Mock()
handler.start_flow = Mock(side_effect=Exception("Handler error"))
handler.start_flow = AsyncMock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# Should not raise
await config_receiver.start_flow("flow1", flow_data)
await config_receiver.start_flow("default", "flow1", flow_data)
handler.start_flow.assert_called_once_with("flow1", flow_data)
handler.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_stop_flow_with_handlers(self):
"""Test stop_flow method with multiple handlers"""
"""stop_flow fans out to every registered flow handler."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler1.stop_flow = Mock()
handler1.stop_flow = AsyncMock()
handler2 = Mock()
handler2.stop_flow = Mock()
handler2.stop_flow = AsyncMock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.stop_flow("flow1", flow_data)
await config_receiver.stop_flow("default", "flow1", flow_data)
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
handler1.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
handler2.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_stop_flow_with_handler_exception(self):
"""Test stop_flow method handles handler exceptions"""
"""Handler exceptions in stop_flow do not propagate."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler = Mock()
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
handler.stop_flow = AsyncMock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# Should not raise
await config_receiver.stop_flow("flow1", flow_data)
await config_receiver.stop_flow("default", "flow1", flow_data)
handler.stop_flow.assert_called_once_with("flow1", flow_data)
handler.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@patch('asyncio.create_task')
@pytest.mark.asyncio
@ -329,25 +323,25 @@ class TestConfigReceiver:
mock_create_task.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_and_apply_mixed_flow_operations(self):
"""Test fetch_and_apply with mixed add/remove operations"""
async def test_fetch_and_apply_workspace_mixed_flow_operations(self):
"""fetch_and_apply_workspace adds, keeps and removes flows in one pass."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate
config_receiver.flows = {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
"default": {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"},
}
}
# Config removes flow1, keeps flow2, adds flow3
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow2": '{"name": "test_flow_2"}',
"flow3": '{"name": "test_flow_3"}'
"flow3": '{"name": "test_flow_3"}',
}
}
@ -358,20 +352,22 @@ class TestConfigReceiver:
start_calls = []
stop_calls = []
async def mock_start_flow(id, flow):
start_calls.append((id, flow))
async def mock_stop_flow(id, flow):
stop_calls.append((id, flow))
async def mock_start_flow(workspace, id, flow):
start_calls.append((workspace, id, flow))
async def mock_stop_flow(workspace, id, flow):
stop_calls.append((workspace, id, flow))
config_receiver.start_flow = mock_start_flow
config_receiver.stop_flow = mock_stop_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert "flow1" not in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow3" in config_receiver.flows
ws_flows = config_receiver.flows["default"]
assert "flow1" not in ws_flows
assert "flow2" in ws_flows
assert "flow3" in ws_flows
assert len(start_calls) == 1
assert start_calls[0][0] == "flow3"
assert start_calls[0][:2] == ("default", "flow3")
assert len(stop_calls) == 1
assert stop_calls[0][0] == "flow1"
assert stop_calls[0][:2] == ("default", "flow1")
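
Note on the tests above: they exercise a per-workspace reconcile step. Flows present in the fetched config but not held locally are started, flows held locally but absent from the config are stopped, and the rest are left alone. A minimal sketch of that diffing logic follows, assuming the nested `flows[workspace][flow_id]` layout that the assertions imply. `reconcile_workspace` and its callback parameters are hypothetical names for illustration, not the actual `ConfigReceiver` code.

```python
import asyncio
import json


async def reconcile_workspace(flows, workspace, wanted_raw, start_flow, stop_flow):
    """Bring flows[workspace] in line with wanted_raw (flow id -> JSON string).

    Hypothetical helper: start_flow / stop_flow are async callables taking
    (workspace, flow_id, flow_definition), mirroring the mocks in the tests.
    """
    wanted = {fid: json.loads(raw) for fid, raw in wanted_raw.items()}
    current = flows.setdefault(workspace, {})

    # Start flows that appear in the config but are not held locally.
    for fid in sorted(wanted.keys() - current.keys()):
        await start_flow(workspace, fid, wanted[fid])
        current[fid] = wanted[fid]

    # Stop flows that are held locally but were removed from the config.
    for fid in sorted(current.keys() - wanted.keys()):
        await stop_flow(workspace, fid, current.pop(fid))


async def demo():
    flows = {"default": {"flow1": {"name": "test_flow_1"},
                         "flow2": {"name": "test_flow_2"}}}
    started, stopped = [], []

    async def start(ws, fid, flow):
        started.append((ws, fid))

    async def stop(ws, fid, flow):
        stopped.append((ws, fid))

    # Config removes flow1, keeps flow2, adds flow3.
    await reconcile_workspace(
        flows, "default",
        {"flow2": '{"name": "test_flow_2"}', "flow3": '{"name": "test_flow_3"}'},
        start, stop,
    )
    return flows, started, stopped
```

Running `asyncio.run(demo())` reproduces the mixed add/keep/remove scenario from the last test. Other workspaces in `flows` are never touched, which is the isolation property the workspace argument is there to provide.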


@ -72,10 +72,10 @@ class TestDispatcherManager:
flow_data = {"name": "test_flow", "steps": []}
await manager.start_flow("flow1", flow_data)
assert "flow1" in manager.flows
assert manager.flows["flow1"] == flow_data
await manager.start_flow("default", "flow1", flow_data)
assert ("default", "flow1") in manager.flows
assert manager.flows[("default", "flow1")] == flow_data
@pytest.mark.asyncio
async def test_stop_flow(self):
@ -86,11 +86,11 @@ class TestDispatcherManager:
# Pre-populate with a flow
flow_data = {"name": "test_flow", "steps": []}
manager.flows["flow1"] = flow_data
await manager.stop_flow("flow1", flow_data)
assert "flow1" not in manager.flows
manager.flows[("default", "flow1")] = flow_data
await manager.stop_flow("default", "flow1", flow_data)
assert ("default", "flow1") not in manager.flows
def test_dispatch_global_service_returns_wrapper(self):
"""Test dispatch_global_service returns DispatcherWrapper"""
@ -275,12 +275,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \
patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value = "test-uuid"
@ -290,7 +290,7 @@ class TestDispatcherManager:
mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
mock_dispatchers.__contains__.return_value = True
params = {"flow": "test_flow", "kind": "triples"}
result = await manager.process_flow_import("ws", "running", params)
@ -326,12 +326,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers:
mock_dispatchers.__contains__.return_value = False
@ -348,12 +348,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \
patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value = "test-uuid"
@ -404,7 +404,7 @@ class TestDispatcherManager:
params = {"flow": "test_flow", "kind": "agent"}
result = await manager.process_flow_service("data", "responder", params)
manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent")
manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent")
assert result == "flow_result"
@pytest.mark.asyncio
@ -415,14 +415,14 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Add flow to the flows dictionary
manager.flows["test_flow"] = {"services": {"agent": {}}}
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
# Pre-populate with existing dispatcher
mock_dispatcher = Mock()
mock_dispatcher.process = AsyncMock(return_value="cached_result")
manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
mock_dispatcher.process.assert_called_once_with("data", "responder")
assert result == "cached_result"
@ -435,7 +435,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"agent": {
"request": "agent_request_queue",
@ -443,7 +443,7 @@ class TestDispatcherManager:
}
}
}
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers:
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
@ -452,23 +452,23 @@ class TestDispatcherManager:
mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
mock_dispatchers.__contains__.return_value = True
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
# Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with(
backend=mock_backend,
request_queue="agent_request_queue",
response_queue="agent_response_queue",
timeout=120,
consumer="api-gateway-test_flow-agent-request",
subscriber="api-gateway-test_flow-agent-request"
consumer="api-gateway-default-test_flow-agent-request",
subscriber="api-gateway-default-test_flow-agent-request"
)
mock_dispatcher.start.assert_called_once()
mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached
assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher
assert result == "new_result"
@pytest.mark.asyncio
@ -479,26 +479,26 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"text-load": {"flow": "text_load_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
mock_rr_dispatchers.__contains__.return_value = False
mock_sender_dispatchers.__contains__.return_value = True
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
mock_dispatcher.start = AsyncMock()
mock_dispatcher.process = AsyncMock(return_value="sender_result")
mock_dispatcher_class.return_value = mock_dispatcher
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load")
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load")
# Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with(
backend=mock_backend,
@ -506,9 +506,9 @@ class TestDispatcherManager:
)
mock_dispatcher.start.assert_called_once()
mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher
assert result == "sender_result"
@pytest.mark.asyncio
@ -519,7 +519,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
with pytest.raises(RuntimeError, match="Invalid flow"):
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
@pytest.mark.asyncio
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
@ -529,14 +529,14 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow without agent interface
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"text-completion": {"request": "req", "response": "resp"}
}
}
with pytest.raises(RuntimeError, match="This kind not supported by flow"):
await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
@pytest.mark.asyncio
async def test_invoke_flow_service_invalid_kind(self):
@ -546,7 +546,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow with interface but unsupported kind
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"invalid-kind": {"request": "req", "response": "resp"}
}
@ -558,7 +558,7 @@ class TestDispatcherManager:
mock_sender_dispatchers.__contains__.return_value = False
with pytest.raises(RuntimeError, match="Invalid kind"):
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind")
@pytest.mark.asyncio
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
@ -608,7 +608,7 @@ class TestDispatcherManager:
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"agent": {
"request": "agent_request_queue",
@ -630,7 +630,7 @@ class TestDispatcherManager:
mock_rr_dispatchers.__contains__.return_value = True
results = await asyncio.gather(*[
manager.invoke_flow_service("data", "responder", "test_flow", "agent")
manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
for _ in range(5)
])
@ -638,5 +638,5 @@ class TestDispatcherManager:
"Dispatcher class instantiated more than once — duplicate consumer bug"
)
assert mock_dispatcher.start.call_count == 1
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher
assert all(r == "result" for r in results)
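
Note on the dispatcher tests above: the cache key grows from `(flow, kind)` to `(workspace, flow, kind)`, and the concurrency test expects five simultaneous calls to build exactly one dispatcher. One common way to get both properties is a lock with a second lookup inside it. The sketch below is an illustration under that assumption; `DispatcherCache` and `factory` are hypothetical names, and the real `DispatcherManager` may be structured differently.

```python
import asyncio


class DispatcherCache:
    """Hypothetical cache of dispatchers keyed by (workspace, flow, kind)."""

    def __init__(self, factory):
        self._factory = factory      # async callable: (workspace, flow, kind) -> dispatcher
        self._dispatchers = {}
        self._lock = asyncio.Lock()

    async def get(self, workspace, flow, kind):
        key = (workspace, flow, kind)
        # Fast path: dispatcher already built for this workspace/flow/kind.
        if key in self._dispatchers:
            return self._dispatchers[key]
        async with self._lock:
            # Re-check inside the lock: another task may have built it
            # while this one was waiting.
            if key not in self._dispatchers:
                self._dispatchers[key] = await self._factory(*key)
            return self._dispatchers[key]


async def demo():
    created = []

    async def factory(workspace, flow, kind):
        await asyncio.sleep(0)       # yield so concurrent callers can interleave
        created.append((workspace, flow, kind))
        return object()

    cache = DispatcherCache(factory)  # built inside the running event loop
    results = await asyncio.gather(*[
        cache.get("default", "test_flow", "agent") for _ in range(5)
    ])
    other = await cache.get("ws2", "test_flow", "agent")
    return created, results, other
```

Because the workspace is part of the key, a flow named `test_flow` in `ws2` gets its own dispatcher instead of sharing the one built for `default`.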


@ -239,7 +239,7 @@ def _make_processor(tools=None):
agent = MagicMock()
agent.tools = tools or {}
agent.additional_context = ""
processor.agent = agent
processor.agents = {"default": agent}
processor.aggregator = MagicMock()
return processor
@ -254,6 +254,7 @@ def _make_flow():
return producers[name]
flow = MagicMock(side_effect=factory)
flow.workspace = "default"
return flow
@ -299,7 +300,7 @@ class TestAgentReactDagStructure:
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
@ -433,7 +434,7 @@ class TestAgentPlanDagStructure:
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
@ -537,7 +538,7 @@ class TestAgentSupervisorDagStructure:
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)


@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic:
"""Test parsing of schema configuration"""
processor = MagicMock()
processor.schemas = {}
processor.schema_builders = {}
processor.graphql_schemas = {}
processor.config_key = "schema"
processor.schema_builder = MagicMock()
processor.schema_builder.clear = MagicMock()
processor.schema_builder.add_schema = MagicMock()
processor.schema_builder.build = MagicMock(return_value=MagicMock())
processor.query_cassandra = MagicMock()
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config
@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic:
}
# Process config
await processor.on_schema_config(schema_config, version=1)
await processor.on_schema_config("default", schema_config, version=1)
# Verify schema was loaded
assert "customer" in processor.schemas
schema = processor.schemas["customer"]
assert "customer" in processor.schemas["default"]
schema = processor.schemas["default"]["customer"]
assert schema.name == "customer"
assert len(schema.fields) == 3
@ -147,24 +146,26 @@ class TestRowsGraphQLQueryLogic:
status_field = next(f for f in schema.fields if f.name == "status")
assert status_field.enum_values == ["active", "inactive"]
# Verify schema builder was called
processor.schema_builder.add_schema.assert_called_once()
processor.schema_builder.build.assert_called_once()
# Verify per-workspace schema builder was created and graphql schema built
assert "default" in processor.schema_builders
assert "default" in processor.graphql_schemas
@pytest.mark.asyncio
async def test_graphql_context_handling(self):
"""Test GraphQL execution context setup"""
processor = MagicMock()
processor.graphql_schema = AsyncMock()
graphql_schema = AsyncMock()
processor.graphql_schemas = {"default": graphql_schema}
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Mock schema execution
mock_result = MagicMock()
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
mock_result.errors = None
processor.graphql_schema.execute.return_value = mock_result
graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
workspace="default",
query='{ customers { id name } }',
variables={},
operation_name=None,
@ -173,8 +174,8 @@ class TestRowsGraphQLQueryLogic:
)
# Verify schema.execute was called with correct context
processor.graphql_schema.execute.assert_called_once()
call_args = processor.graphql_schema.execute.call_args
graphql_schema.execute.assert_called_once()
call_args = graphql_schema.execute.call_args
# Verify context was passed
context = call_args[1]['context_value']
@ -190,7 +191,8 @@ class TestRowsGraphQLQueryLogic:
async def test_error_handling_graphql_errors(self):
"""Test GraphQL error handling and conversion"""
processor = MagicMock()
processor.graphql_schema = AsyncMock()
graphql_schema = AsyncMock()
processor.graphql_schemas = {"default": graphql_schema}
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Create a simple object to simulate GraphQL error
@ -212,9 +214,10 @@ class TestRowsGraphQLQueryLogic:
mock_result = MagicMock()
mock_result.data = None
mock_result.errors = [mock_error]
processor.graphql_schema.execute.return_value = mock_result
graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
workspace="default",
query='{ customers { invalid_field } }',
variables={},
operation_name=None,
@ -259,6 +262,7 @@ class TestRowsGraphQLQueryLogic:
# Mock flow
mock_flow = MagicMock()
mock_flow.workspace = "default"
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
@ -267,6 +271,7 @@ class TestRowsGraphQLQueryLogic:
# Verify query was executed
processor.execute_graphql_query.assert_called_once_with(
workspace="default",
query='{ customers { id name } }',
variables={},
operation_name=None,


@ -286,11 +286,11 @@ class TestNLPQueryProcessor:
}
# Act
await processor.on_schema_config(config, "v1")
await processor.on_schema_config("default", config, "v1")
# Assert
assert "test_schema" in processor.schemas
schema = processor.schemas["test_schema"]
assert "test_schema" in processor.schemas["default"]
schema = processor.schemas["default"]["test_schema"]
assert schema.name == "test_schema"
assert schema.description == "Test schema"
assert len(schema.fields) == 2
@ -308,10 +308,10 @@ class TestNLPQueryProcessor:
}
# Act
await processor.on_schema_config(config, "v1")
await processor.on_schema_config("default", config, "v1")
# Assert - bad schema should be ignored
assert "bad_schema" not in processor.schemas
assert "bad_schema" not in processor.schemas.get("default", {})
def test_processor_initialization(self, mock_pulsar_client):
"""Test processor initialization with correct specifications"""


@ -101,7 +101,7 @@ def service(mock_schemas):
taskgroup=MagicMock(),
id="test-processor"
)
service.schemas = mock_schemas
service.schemas = {"default": dict(mock_schemas)}
return service
@ -109,6 +109,7 @@ def service(mock_schemas):
def mock_flow():
"""Create mock flow with prompt service"""
flow = MagicMock()
flow.workspace = "default"
prompt_request_flow = AsyncMock()
flow.return_value.request = prompt_request_flow
return flow, prompt_request_flow


@ -17,6 +17,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class _MockFlowDefault:
"""Mock Flow with default workspace for testing."""
workspace = "default"
name = "default"
id = "test-processor"
mock_flow_default = _MockFlowDefault()
class TestRowsCassandraStorageLogic:
"""Test business logic for unified table implementation"""
@ -145,11 +156,11 @@ class TestRowsCassandraStorageLogic:
}
# Process configuration
await processor.on_schema_config(config, version=1)
await processor.on_schema_config("default", config, version=1)
# Verify schema was loaded
assert "customer_records" in processor.schemas
schema = processor.schemas["customer_records"]
assert "customer_records" in processor.schemas["default"]
schema = processor.schemas["default"]["customer_records"]
assert schema.name == "customer_records"
assert len(schema.fields) == 3
@ -165,14 +176,16 @@ class TestRowsCassandraStorageLogic:
"""Test that row processing stores data as map<text, text>"""
processor = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="string", size=100)
]
)
"default": {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="string", size=100)
]
)
}
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
@ -205,7 +218,7 @@ class TestRowsCassandraStorageLogic:
msg.value.return_value = test_obj
# Process object
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify insert was executed
mock_async_execute.assert_called()
@ -230,14 +243,16 @@ class TestRowsCassandraStorageLogic:
"""Test that row is written once per indexed field"""
processor = MagicMock()
processor.schemas = {
"multi_index_schema": RowSchema(
name="multi_index_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="status", type="string", indexed=True)
]
)
"default": {
"multi_index_schema": RowSchema(
name="multi_index_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="status", type="string", indexed=True)
]
)
}
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
@ -267,7 +282,7 @@ class TestRowsCassandraStorageLogic:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should have 3 inserts (one per indexed field: id, category, status)
assert mock_async_execute.call_count == 3
@ -290,13 +305,15 @@ class TestRowsCassandraStorageBatchLogic:
"""Test processing of batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"batch_schema": RowSchema(
name="batch_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string")
]
)
"default": {
"batch_schema": RowSchema(
name="batch_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string")
]
)
}
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
@ -331,7 +348,7 @@ class TestRowsCassandraStorageBatchLogic:
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should have 3 inserts (one per row, one index per row since only primary key)
assert mock_async_execute.call_count == 3
@ -349,10 +366,12 @@ class TestRowsCassandraStorageBatchLogic:
"""Test processing of empty batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", primary=True)]
)
"default": {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", primary=True)]
)
}
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
@ -381,7 +400,7 @@ class TestRowsCassandraStorageBatchLogic:
msg = MagicMock()
msg.value.return_value = empty_batch_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify no insert calls for empty batch
processor.session.execute.assert_not_called()
@ -446,19 +465,21 @@ class TestPartitionRegistration:
processor.registered_partitions = set()
processor.session = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True)
]
)
"default": {
"test_schema": RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True)
]
)
}
}
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.register_partitions("test_user", "test_collection", "test_schema")
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
# Should have 2 inserts (one per index: id, category)
assert processor.session.execute.call_count == 2
@ -473,7 +494,7 @@ class TestPartitionRegistration:
processor.session = MagicMock()
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.register_partitions("test_user", "test_collection", "test_schema")
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
# Should not execute any CQL since already registered
processor.session.execute.assert_not_called()
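
Review note: the fixtures above change `processor.schemas` from a flat `{name: schema}` map to a workspace-keyed `{workspace: {name: schema}}` map, and `register_partitions` gains a trailing workspace argument. The sketch below shows only the lookup shape this implies. The `lookup_schema` helper and the plain dicts standing in for `RowSchema` are illustrative assumptions, not code from this commit.

```python
from typing import Any, Dict, Optional

# Workspace-keyed schema registry: {workspace: {schema_name: schema}}.
Schemas = Dict[str, Dict[str, Any]]


def lookup_schema(schemas: Schemas, workspace: str, name: str) -> Optional[Any]:
    """Hypothetical helper: return the schema registered under `name`
    in `workspace`, or None if either level is missing."""
    return schemas.get(workspace, {}).get(name)


# Plain dicts stand in for RowSchema objects here.
schemas: Schemas = {
    "default": {"test_schema": {"fields": ["id", "category"]}},
}

found = lookup_schema(schemas, "default", "test_schema")
missing = lookup_schema(schemas, "other", "test_schema")
```

With this shape, the same schema name can exist in two workspaces without colliding.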


@ -50,7 +50,7 @@ class Api:
token: Optional bearer token for authentication
"""
def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None):
def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None, workspace: str = "default"):
"""
Initialize the TrustGraph API client.
@ -82,6 +82,7 @@ class Api:
self.timeout = timeout
self.token = token
self.workspace = workspace
# Lazy initialization for new clients
self._socket_client = None
@ -137,7 +138,7 @@ class Api:
config.put([ConfigValue(type="llm", key="model", value="gpt-4")])
```
"""
return Config(api=self)
return Config(api=self, workspace=self.workspace)
def knowledge(self):
"""
@ -191,6 +192,12 @@ class Api:
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
# Ensure every REST request carries the workspace so services can
# scope their behaviour. Callers that already set workspace in the
# payload (e.g. Library client) take precedence.
if isinstance(request, dict) and "workspace" not in request:
request = {**request, "workspace": self.workspace}
# Invoke the API, input is passed as JSON
resp = requests.post(url, json=request, timeout=self.timeout, headers=headers)
@ -297,7 +304,10 @@ class Api:
from . socket_client import SocketClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
self._socket_client = SocketClient(base_url, self.timeout, self.token)
self._socket_client = SocketClient(
base_url, self.timeout, self.token,
workspace=self.workspace,
)
return self._socket_client
def bulk(self):
@ -417,7 +427,10 @@ class Api:
from . async_socket_client import AsyncSocketClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
self._async_socket_client = AsyncSocketClient(base_url, self.timeout, self.token)
self._async_socket_client = AsyncSocketClient(
base_url, self.timeout, self.token,
workspace=self.workspace,
)
return self._async_socket_client
def async_bulk(self):
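
Review note: the REST hunk above adds the client's workspace to every dict payload unless the caller has already set one. A minimal standalone sketch of that guard follows. The `with_workspace` name is hypothetical; in the commit the logic is inline in the request method.

```python
from typing import Any


def with_workspace(request: Any, workspace: str) -> Any:
    """Hypothetical helper mirroring the guard in the diff: add the
    client's workspace to dict payloads that do not already carry one."""
    if isinstance(request, dict) and "workspace" not in request:
        # Build a new dict so the caller's payload is not mutated.
        return {**request, "workspace": workspace}
    return request


# Payload without a workspace: the client default is injected.
plain = with_workspace({"operation": "list-flows"}, "team-a")

# Payload that already names a workspace: the caller's value wins.
explicit = with_workspace({"operation": "get", "workspace": "team-b"}, "team-a")
```

Because the guard copies rather than mutates, a request dict reused across clients with different workspaces is not silently pinned to the first one.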


@ -22,10 +22,14 @@ class AsyncSocketClient:
Or call connect()/aclose() manually.
"""
def __init__(self, url: str, timeout: int, token: Optional[str]):
def __init__(
self, url: str, timeout: int, token: Optional[str],
workspace: str = "default",
):
self.url = self._convert_to_ws_url(url)
self.timeout = timeout
self.token = token
self.workspace = workspace
self._request_counter = 0
self._socket = None
self._connect_cm = None
@ -117,6 +121,7 @@ class AsyncSocketClient:
try:
message = {
"id": request_id,
"workspace": self.workspace,
"service": service,
"request": request
}
@ -149,6 +154,7 @@ class AsyncSocketClient:
try:
message = {
"id": request_id,
"workspace": self.workspace,
"service": service,
"request": request
}
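
Review note: on the WebSocket path the workspace travels in the message envelope, next to `id` and `service`, rather than inside the service request itself. The sketch below builds such an envelope using only the four keys visible in the hunks above. The `build_envelope` helper and the JSON serialisation are assumptions for illustration; the real client may add further fields.

```python
import json


def build_envelope(request_id: str, workspace: str, service: str,
                   request: dict) -> str:
    """Hypothetical helper: wrap a service request in the envelope
    shape shown in the diff and serialise it as JSON."""
    message = {
        "id": request_id,
        "workspace": workspace,
        "service": service,
        "request": request,  # inner payload is passed through untouched
    }
    return json.dumps(message)


envelope = build_envelope(
    "req-1", "team-a", "text-completion", {"prompt": "hi"},
)
decoded = json.loads(envelope)
```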


@ -2,11 +2,9 @@
TrustGraph Collection Management
This module provides interfaces for managing data collections in TrustGraph.
Collections provide logical grouping and isolation for documents and knowledge
graph data.
Collections provide logical grouping within a workspace.
"""
import datetime
import logging
from . types import CollectionMetadata
@ -18,10 +16,9 @@ class Collection:
"""
Collection management client.
Provides methods for managing data collections, including listing,
updating metadata, and deleting collections. Collections organize
documents and knowledge graph data into logical groupings for
isolation and access control.
Provides methods for managing data collections within the configured
workspace, including listing, updating metadata, and deleting
collections.
"""
def __init__(self, api):
@ -45,45 +42,20 @@ class Collection:
"""
return self.api.request(f"collection-management", request)
def list_collections(self, user, tag_filter=None):
def list_collections(self, tag_filter=None):
"""
List all collections for a user.
Retrieves metadata for all collections owned by the specified user,
with optional filtering by tags.
List all collections in this workspace.
Args:
user: User identifier
tag_filter: Optional list of tags to filter collections (default: None)
tag_filter: Optional list of tags to filter collections
Returns:
list[CollectionMetadata]: List of collection metadata objects
Raises:
ProtocolException: If response format is invalid
Example:
```python
collection = api.collection()
# List all collections
all_colls = collection.list_collections(user="trustgraph")
for coll in all_colls:
print(f"{coll.collection}: {coll.name}")
print(f" Description: {coll.description}")
print(f" Tags: {', '.join(coll.tags)}")
# List collections with specific tags
research_colls = collection.list_collections(
user="trustgraph",
tag_filter=["research", "published"]
)
```
"""
input = {
"operation": "list-collections",
"user": user,
"workspace": self.api.workspace,
}
if tag_filter:
@ -92,7 +64,6 @@ class Collection:
object = self.request(input)
try:
# Handle case where collections might be None or missing
if object is None or "collections" not in object:
return []
@ -102,7 +73,6 @@ class Collection:
return [
CollectionMetadata(
user = v["user"],
collection = v["collection"],
name = v["name"],
description = v["description"],
@ -114,15 +84,11 @@ class Collection:
logger.error("Failed to parse collection list response", exc_info=True)
raise ProtocolException(f"Response not formatted correctly")
def update_collection(self, user, collection, name=None, description=None, tags=None):
def update_collection(self, collection, name=None, description=None, tags=None):
"""
Update collection metadata.
Updates the name, description, and/or tags for an existing collection.
Only provided fields are updated; others remain unchanged.
Args:
user: User identifier
collection: Collection identifier
name: New collection name (optional)
description: New collection description (optional)
@ -130,35 +96,11 @@ class Collection:
Returns:
CollectionMetadata: Updated collection metadata, or None if not found
Raises:
ProtocolException: If response format is invalid
Example:
```python
collection_api = api.collection()
# Update collection metadata
updated = collection_api.update_collection(
user="trustgraph",
collection="default",
name="Default Collection",
description="Main data collection for general use",
tags=["default", "production"]
)
# Update only specific fields
updated = collection_api.update_collection(
user="trustgraph",
collection="research",
description="Updated description"
)
```
"""
input = {
"operation": "update-collection",
"user": user,
"workspace": self.api.workspace,
"collection": collection,
}
@ -175,7 +117,6 @@ class Collection:
if "collections" in object and object["collections"]:
v = object["collections"][0]
return CollectionMetadata(
user = v["user"],
collection = v["collection"],
name = v["name"],
description = v["description"],
@ -186,37 +127,23 @@ class Collection:
logger.error("Failed to parse collection update response", exc_info=True)
raise ProtocolException(f"Response not formatted correctly")
def delete_collection(self, user, collection):
def delete_collection(self, collection):
"""
Delete a collection.
Removes a collection and all its associated data from the system.
Args:
user: User identifier
collection: Collection identifier to delete
Returns:
dict: Empty response object
Example:
```python
collection_api = api.collection()
# Delete a collection
collection_api.delete_collection(
user="trustgraph",
collection="old-collection"
)
```
"""
input = {
"operation": "delete-collection",
"user": user,
"workspace": self.api.workspace,
"collection": collection,
}
object = self.request(input)
self.request(input)
return {}
return {}


@ -21,14 +21,16 @@ class Config:
and list operations.
"""
def __init__(self, api):
def __init__(self, api, workspace="default"):
"""
Initialize Config client.
Args:
api: Parent Api instance for making requests
workspace: Workspace to scope all config operations to
"""
self.api = api
self.workspace = workspace
def request(self, request):
"""
@ -75,9 +77,9 @@ class Config:
```
"""
# The input consists of system and prompt strings
input = {
"operation": "get",
"workspace": self.workspace,
"keys": [
{ "type": k.type, "key": k.key }
for k in keys
@ -123,9 +125,9 @@ class Config:
```
"""
# The input consists of system and prompt strings
input = {
"operation": "put",
"workspace": self.workspace,
"values": [
{ "type": v.type, "key": v.key, "value": v.value }
for v in values
@ -157,9 +159,9 @@ class Config:
```
"""
# The input consists of system and prompt strings
input = {
"operation": "delete",
"workspace": self.workspace,
"keys": [
{ "type": v.type, "key": v.key }
for v in keys
@ -195,9 +197,9 @@ class Config:
```
"""
# The input consists of system and prompt strings
input = {
"operation": "list",
"workspace": self.workspace,
"type": type,
}
@ -235,9 +237,9 @@ class Config:
```
"""
# The input consists of system and prompt strings
input = {
"operation": "getvalues",
"workspace": self.workspace,
"type": type,
}
@ -255,6 +257,46 @@ class Config:
except:
raise ProtocolException(f"Response not formatted correctly")
def get_values_all_workspaces(self, type):
"""
Get all configuration values of a given type across all workspaces.
Unlike get_values(), this is not scoped to a single workspace;
it returns every entry of the given type in the system. Each
returned ConfigValue includes its workspace field. Used by
shared processors to load type-scoped config at startup.
Args:
type: Configuration type (e.g. "prompt", "schema")
Returns:
list[ConfigValue]: Values across all workspaces; each has
its workspace field populated.
Raises:
ProtocolException: If response format is invalid
"""
input = {
"operation": "getvalues-all-ws",
"type": type,
}
object = self.request(input)
try:
return [
ConfigValue(
type = v["type"],
key = v["key"],
value = v["value"],
workspace = v.get("workspace", ""),
)
for v in object["values"]
]
except Exception:
raise ProtocolException("Response not formatted correctly")
def all(self):
"""
Get complete configuration and version.
@ -279,9 +321,9 @@ class Config:
```
"""
# The input consists of system and prompt strings
input = {
"operation": "config"
"operation": "config",
"workspace": self.workspace,
}
object = self.request(input)
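
Review note: `get_values_all_workspaces()` returns a flat list in which each `ConfigValue` carries its workspace. A consumer that wants per-workspace config has to regroup that list. The sketch below shows one way to do so. The local `ConfigValue` is a stand-in with the same four fields as the dataclass in this commit, and `group_by_workspace` is a hypothetical helper name.

```python
import dataclasses
from typing import Dict, List


@dataclasses.dataclass
class ConfigValue:
    # Stand-in with the same fields as the ConfigValue in this commit.
    type: str
    key: str
    value: str
    workspace: str = ""


def group_by_workspace(values: List[ConfigValue]) -> Dict[str, Dict[str, str]]:
    """Hypothetical helper: regroup a flat list of values into
    {workspace: {key: value}}."""
    grouped: Dict[str, Dict[str, str]] = {}
    for v in values:
        grouped.setdefault(v.workspace, {})[v.key] = v.value
    return grouped


values = [
    ConfigValue("prompt", "system", "You are helpful.", workspace="default"),
    ConfigValue("prompt", "system", "Be terse.", workspace="team-a"),
    ConfigValue("prompt", "extract", "Extract facts.", workspace="team-a"),
]
grouped = group_by_workspace(values)
```

The same key (`system` here) can appear once per workspace, which is why the workspace must be the outer grouping level.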


@ -115,72 +115,32 @@ class Flow:
return FlowInstance(api=self, id=id)
def list_blueprints(self):
"""
List all available flow blueprints.
"""List blueprints in the current workspace."""
Returns:
list[str]: List of blueprint names
Example:
```python
blueprints = api.flow().list_blueprints()
print(blueprints) # ['default', 'custom-flow', ...]
```
"""
# The input consists of system and prompt strings
input = {
"operation": "list-blueprints",
"workspace": self.api.workspace,
}
return self.request(request = input)["blueprint-names"]
def get_blueprint(self, blueprint_name):
"""
Get a flow blueprint definition by name.
"""Get a flow blueprint definition by name."""
Args:
blueprint_name: Name of the blueprint to retrieve
Returns:
dict: Blueprint definition as a dictionary
Example:
```python
blueprint = api.flow().get_blueprint("default")
print(blueprint) # Blueprint configuration
```
"""
# The input consists of system and prompt strings
input = {
"operation": "get-blueprint",
"workspace": self.api.workspace,
"blueprint-name": blueprint_name,
}
return json.loads(self.request(request = input)["blueprint-definition"])
def put_blueprint(self, blueprint_name, definition):
"""
Create or update a flow blueprint.
"""Create or update a flow blueprint."""
Args:
blueprint_name: Name for the blueprint
definition: Blueprint definition dictionary
Example:
```python
definition = {
"services": ["text-completion", "graph-rag"],
"parameters": {"model": "gpt-4"}
}
api.flow().put_blueprint("my-blueprint", definition)
```
"""
# The input consists of system and prompt strings
input = {
"operation": "put-blueprint",
"workspace": self.api.workspace,
"blueprint-name": blueprint_name,
"blueprint-definition": json.dumps(definition),
}
@ -188,96 +148,43 @@ class Flow:
self.request(request = input)
def delete_blueprint(self, blueprint_name):
"""
Delete a flow blueprint.
"""Delete a flow blueprint."""
Args:
blueprint_name: Name of the blueprint to delete
Example:
```python
api.flow().delete_blueprint("old-blueprint")
```
"""
# The input consists of system and prompt strings
input = {
"operation": "delete-blueprint",
"workspace": self.api.workspace,
"blueprint-name": blueprint_name,
}
self.request(request = input)
def list(self):
"""
List all active flow instances.
"""List flow instances in the current workspace."""
Returns:
list[str]: List of flow instance IDs
Example:
```python
flows = api.flow().list()
print(flows) # ['default', 'flow-1', 'flow-2', ...]
```
"""
# The input consists of system and prompt strings
input = {
"operation": "list-flows",
"workspace": self.api.workspace,
}
return self.request(request = input)["flow-ids"]
def get(self, id):
"""
Get the definition of a running flow instance.
"""Get the definition of a flow instance."""
Args:
id: Flow instance ID
Returns:
dict: Flow instance definition
Example:
```python
flow_def = api.flow().get("default")
print(flow_def)
```
"""
# The input consists of system and prompt strings
input = {
"operation": "get-flow",
"workspace": self.api.workspace,
"flow-id": id,
}
return json.loads(self.request(request = input)["flow"])
def start(self, blueprint_name, id, description, parameters=None):
"""
Start a new flow instance from a blueprint.
"""Start a new flow instance from a blueprint."""
Args:
blueprint_name: Name of the blueprint to instantiate
id: Unique identifier for the flow instance
description: Human-readable description
parameters: Optional parameters dictionary
Example:
```python
api.flow().start(
blueprint_name="default",
id="my-flow",
description="My custom flow",
parameters={"model": "gpt-4"}
)
```
"""
# The input consists of system and prompt strings
input = {
"operation": "start-flow",
"workspace": self.api.workspace,
"flow-id": id,
"blueprint-name": blueprint_name,
"description": description,
@ -289,21 +196,11 @@ class Flow:
self.request(request = input)
def stop(self, id):
"""
Stop a running flow instance.
"""Stop a running flow instance."""
Args:
id: Flow instance ID to stop
Example:
```python
api.flow().stop("my-flow")
```
"""
# The input consists of system and prompt strings
input = {
"operation": "stop-flow",
"workspace": self.api.workspace,
"flow-id": id,
}
@ -349,6 +246,13 @@ class FlowInstance:
Returns:
dict: Service response
"""
# Inject workspace so the gateway can route to the right
# workspace's flow. If already present, keep the caller's value.
if isinstance(request, dict) and "workspace" not in request:
request = {
"workspace": self.api.api.workspace,
**request,
}
return self.api.request(path = f"{self.id}/{path}", request = request)
def text_completion(self, system, prompt):


@ -63,105 +63,50 @@ class Knowledge:
"""
return self.api.request(f"knowledge", request)
def list_kg_cores(self, user="trustgraph"):
def list_kg_cores(self):
"""
List all available knowledge graph cores.
Retrieves the IDs of all KG cores available for the specified user.
Args:
user: User identifier (default: "trustgraph")
List all available knowledge graph cores in this workspace.
Returns:
list[str]: List of KG core identifiers
Example:
```python
knowledge = api.knowledge()
# List available KG cores
cores = knowledge.list_kg_cores(user="trustgraph")
print(f"Available KG cores: {cores}")
```
"""
# The input consists of system and prompt strings
input = {
"operation": "list-kg-cores",
"user": user,
"workspace": self.api.workspace,
}
return self.request(request = input)["ids"]
def delete_kg_core(self, id, user="trustgraph"):
def delete_kg_core(self, id):
"""
Delete a knowledge graph core.
Removes a KG core from storage. This does not affect currently loaded
cores in flows.
Delete a knowledge graph core in this workspace.
Args:
id: KG core identifier to delete
user: User identifier (default: "trustgraph")
Example:
```python
knowledge = api.knowledge()
# Delete a KG core
knowledge.delete_kg_core(id="medical-kb-v1", user="trustgraph")
```
"""
# The input consists of system and prompt strings
input = {
"operation": "delete-kg-core",
"user": user,
"workspace": self.api.workspace,
"id": id,
}
self.request(request = input)
def load_kg_core(self, id, user="trustgraph", flow="default",
collection="default"):
def load_kg_core(self, id, flow="default", collection="default"):
"""
Load a knowledge graph core into a flow.
Makes a KG core available for use in queries and RAG operations within
the specified flow and collection.
Args:
id: KG core identifier to load
user: User identifier (default: "trustgraph")
flow: Flow instance to load into (default: "default")
collection: Collection to associate with (default: "default")
Example:
```python
knowledge = api.knowledge()
# Load a medical knowledge base into the default flow
knowledge.load_kg_core(
id="medical-kb-v1",
user="trustgraph",
flow="default",
collection="medical"
)
# Now the flow can use this KG core for RAG queries
flow = api.flow().id("default")
response = flow.graph_rag(
query="What are the symptoms of diabetes?",
user="trustgraph",
collection="medical"
)
```
"""
# The input consists of system and prompt strings
input = {
"operation": "load-kg-core",
"user": user,
"workspace": self.api.workspace,
"id": id,
"flow": flow,
"collection": collection,
@ -169,35 +114,18 @@ class Knowledge:
self.request(request = input)
def unload_kg_core(self, id, user="trustgraph", flow="default"):
def unload_kg_core(self, id, flow="default"):
"""
Unload a knowledge graph core from a flow.
Removes a KG core from active use in the specified flow, freeing
resources while keeping the core available in storage.
Args:
id: KG core identifier to unload
user: User identifier (default: "trustgraph")
flow: Flow instance to unload from (default: "default")
Example:
```python
knowledge = api.knowledge()
# Unload a KG core when no longer needed
knowledge.unload_kg_core(
id="medical-kb-v1",
user="trustgraph",
flow="default"
)
```
"""
# The input consists of system and prompt strings
input = {
"operation": "unload-kg-core",
"user": user,
"workspace": self.api.workspace,
"id": id,
"flow": flow,
}


@ -94,7 +94,7 @@ class Library:
return self.api.request(f"librarian", request)
def add_document(
self, document, id, metadata, user, title, comments,
self, document, id, metadata, title, comments,
kind="text/plain", tags=[], on_progress=None,
):
"""
@ -176,7 +176,6 @@ class Library:
document=document,
id=id,
metadata=metadata,
user=user,
title=title,
comments=comments,
kind=kind,
@ -213,6 +212,7 @@ class Library:
input = {
"operation": "add-document",
"workspace": self.api.workspace,
"document-metadata": {
"id": id,
"time": int(time.time()),
@ -220,7 +220,7 @@ class Library:
"title": title,
"comments": comments,
"metadata": triples,
"user": user,
"workspace": self.api.workspace,
"tags": tags
},
"content": base64.b64encode(document).decode("utf-8"),
@ -229,7 +229,7 @@ class Library:
return self.request(input)
def _add_document_chunked(
self, document, id, metadata, user, title, comments,
self, document, id, metadata, title, comments,
kind, tags, on_progress=None,
):
"""
@ -245,13 +245,14 @@ class Library:
# Begin upload session
begin_request = {
"operation": "begin-upload",
"workspace": self.api.workspace,
"document-metadata": {
"id": id,
"time": int(time.time()),
"kind": kind,
"title": title,
"comments": comments,
"user": user,
"workspace": self.api.workspace,
"tags": tags,
},
"total-size": total_size,
@ -279,10 +280,10 @@ class Library:
chunk_request = {
"operation": "upload-chunk",
"workspace": self.api.workspace,
"upload-id": upload_id,
"chunk-index": chunk_index,
"content": base64.b64encode(chunk_data).decode("utf-8"),
"user": user,
}
chunk_response = self.request(chunk_request)
@ -298,8 +299,8 @@ class Library:
# Complete upload
complete_request = {
"operation": "complete-upload",
"workspace": self.api.workspace,
"upload-id": upload_id,
"user": user,
}
complete_response = self.request(complete_request)
@ -314,8 +315,8 @@ class Library:
try:
abort_request = {
"operation": "abort-upload",
"workspace": self.api.workspace,
"upload-id": upload_id,
"user": user,
}
self.request(abort_request)
logger.info(f"Aborted failed upload {upload_id}")
@ -323,7 +324,7 @@ class Library:
logger.warning(f"Failed to abort upload: {abort_error}")
raise
def get_documents(self, user, include_children=False):
def get_documents(self, include_children=False):
"""
List all documents for a user.
@ -359,7 +360,7 @@ class Library:
input = {
"operation": "list-documents",
"user": user,
"workspace": self.api.workspace,
"include-children": include_children,
}
@ -381,7 +382,7 @@ class Library:
)
for w in v["metadata"]
],
user = v["user"],
workspace = v.get("workspace", ""),
tags = v["tags"],
parent_id = v.get("parent-id", ""),
document_type = v.get("document-type", "source"),
@ -392,7 +393,7 @@ class Library:
logger.error("Failed to parse document list response", exc_info=True)
raise ProtocolException(f"Response not formatted correctly")
def get_document(self, user, id):
def get_document(self, id):
"""
Get metadata for a specific document.
@ -419,7 +420,7 @@ class Library:
input = {
"operation": "get-document",
"user": user,
"workspace": self.api.workspace,
"document-id": id,
}
@ -441,7 +442,7 @@ class Library:
)
for w in doc["metadata"]
],
user = doc["user"],
workspace = doc.get("workspace", ""),
tags = doc["tags"],
parent_id = doc.get("parent-id", ""),
document_type = doc.get("document-type", "source"),
@ -450,7 +451,7 @@ class Library:
logger.error("Failed to parse document response", exc_info=True)
raise ProtocolException(f"Response not formatted correctly")
def update_document(self, user, id, metadata):
def update_document(self, id, metadata):
"""
Update document metadata.
@ -490,8 +491,9 @@ class Library:
input = {
"operation": "update-document",
"workspace": self.api.workspace,
"document-metadata": {
"user": user,
"workspace": self.api.workspace,
"document-id": id,
"time": metadata.time,
"title": metadata.title,
@ -526,14 +528,14 @@ class Library:
)
for w in doc["metadata"]
],
user = doc["user"],
workspace = doc.get("workspace", ""),
tags = doc["tags"]
)
except Exception as e:
logger.error("Failed to parse document update response", exc_info=True)
raise ProtocolException(f"Response not formatted correctly")
def remove_document(self, user, id):
def remove_document(self, id):
"""
Remove a document from the library.
@ -555,7 +557,7 @@ class Library:
input = {
"operation": "remove-document",
"user": user,
"workspace": self.api.workspace,
"document-id": id,
}
@ -565,7 +567,7 @@ class Library:
def start_processing(
self, id, document_id, flow="default",
user="trustgraph", collection="default", tags=[],
collection="default", tags=[],
):
"""
Start a document processing workflow.
@ -602,12 +604,13 @@ class Library:
input = {
"operation": "add-processing",
"workspace": self.api.workspace,
"processing-metadata": {
"id": id,
"document-id": document_id,
"time": int(time.time()),
"flow": flow,
"user": user,
"workspace": self.api.workspace,
"collection": collection,
"tags": tags,
}
@ -618,7 +621,7 @@ class Library:
return {}
def stop_processing(
self, id, user="trustgraph",
self, id,
):
"""
Stop a running document processing job.
@ -641,15 +644,15 @@ class Library:
input = {
"operation": "remove-processing",
"workspace": self.api.workspace,
"processing-id": id,
"user": user,
}
object = self.request(input)
return {}
def get_processings(self, user="trustgraph"):
def get_processings(self):
"""
List all active document processing jobs.
@ -681,7 +684,7 @@ class Library:
input = {
"operation": "list-processing",
"user": user,
"workspace": self.api.workspace,
}
object = self.request(input)
@ -693,7 +696,7 @@ class Library:
document_id = v["document-id"],
time = datetime.datetime.fromtimestamp(v["time"]),
flow = v["flow"],
user = v["user"],
workspace = v.get("workspace", ""),
collection = v["collection"],
tags = v["tags"],
)
@ -705,7 +708,7 @@ class Library:
# Chunked upload management methods
def get_pending_uploads(self, user):
def get_pending_uploads(self):
"""
List all pending (in-progress) uploads for a user.
@ -731,14 +734,14 @@ class Library:
"""
input = {
"operation": "list-uploads",
"user": user,
"workspace": self.api.workspace,
}
response = self.request(input)
return response.get("upload-sessions", [])
def get_upload_status(self, upload_id, user):
def get_upload_status(self, upload_id):
"""
Get the status of a specific upload.
@ -774,13 +777,13 @@ class Library:
"""
input = {
"operation": "get-upload-status",
"workspace": self.api.workspace,
"upload-id": upload_id,
"user": user,
}
return self.request(input)
def abort_upload(self, upload_id, user):
def abort_upload(self, upload_id):
"""
Abort an in-progress upload.
@ -801,13 +804,13 @@ class Library:
"""
input = {
"operation": "abort-upload",
"workspace": self.api.workspace,
"upload-id": upload_id,
"user": user,
}
return self.request(input)
def resume_upload(self, upload_id, document, user, on_progress=None):
def resume_upload(self, upload_id, document, on_progress=None):
"""
Resume an interrupted upload.
@ -844,7 +847,7 @@ class Library:
```
"""
# Get current status
status = self.get_upload_status(upload_id, user)
status = self.get_upload_status(upload_id)
if status.get("upload-state") == "expired":
raise RuntimeError("Upload session has expired, please start a new upload")
@ -867,10 +870,10 @@ class Library:
chunk_request = {
"operation": "upload-chunk",
"workspace": self.api.workspace,
"upload-id": upload_id,
"chunk-index": chunk_index,
"content": base64.b64encode(chunk_data).decode("utf-8"),
"user": user,
}
self.request(chunk_request)
@ -886,8 +889,8 @@ class Library:
# Complete upload
complete_request = {
"operation": "complete-upload",
"workspace": self.api.workspace,
"upload-id": upload_id,
"user": user,
}
return self.request(complete_request)
@ -895,7 +898,7 @@ class Library:
# Child document methods
def add_child_document(
self, document, id, parent_id, user, title, comments,
self, document, id, parent_id, title, comments,
kind="text/plain", tags=[], metadata=None,
):
"""
@ -964,6 +967,7 @@ class Library:
input = {
"operation": "add-child-document",
"workspace": self.api.workspace,
"document-metadata": {
"id": id,
"time": int(time.time()),
@ -971,7 +975,7 @@ class Library:
"title": title,
"comments": comments,
"metadata": triples,
"user": user,
"workspace": self.api.workspace,
"tags": tags,
"parent-id": parent_id,
"document-type": "extracted",
@ -981,7 +985,7 @@ class Library:
return self.request(input)
def list_children(self, document_id, user):
def list_children(self, document_id):
"""
List all child documents for a given parent document.
@ -1006,8 +1010,8 @@ class Library:
"""
input = {
"operation": "list-children",
"workspace": self.api.workspace,
"document-id": document_id,
"user": user,
}
response = self.request(input)
@ -1028,7 +1032,7 @@ class Library:
)
for w in v.get("metadata", [])
],
user=v["user"],
workspace=v.get("workspace", ""),
tags=v.get("tags", []),
parent_id=v.get("parent-id", ""),
document_type=v.get("document-type", "source"),
@ -1039,7 +1043,7 @@ class Library:
logger.error("Failed to parse children response", exc_info=True)
raise ProtocolException("Response not formatted correctly")
def get_document_content(self, user, id):
def get_document_content(self, id):
"""
Get the content of a document.
@ -1067,7 +1071,7 @@ class Library:
"""
input = {
"operation": "get-document-content",
"user": user,
"workspace": self.api.workspace,
"document-id": id,
}
@ -1076,7 +1080,7 @@ class Library:
return base64.b64decode(content_b64)
def stream_document_to_file(self, user, id, file_path, chunk_size=1024*1024, on_progress=None):
def stream_document_to_file(self, id, file_path, chunk_size=1024*1024, on_progress=None):
"""
Stream document content to a file.
@ -1116,7 +1120,7 @@ class Library:
while True:
input = {
"operation": "stream-document",
"user": user,
"workspace": self.api.workspace,
"document-id": id,
"chunk-index": chunk_index,
"chunk-size": chunk_size,


@ -84,10 +84,14 @@ class SocketClient:
for streaming responses.
"""
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
def __init__(
self, url: str, timeout: int, token: Optional[str],
workspace: str = "default",
) -> None:
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
self.workspace: str = workspace
self._request_counter: int = 0
self._lock: Lock = Lock()
self._loop: Optional[asyncio.AbstractEventLoop] = None
@ -251,6 +255,7 @@ class SocketClient:
try:
message = {
"id": request_id,
"workspace": self.workspace,
"service": service,
"request": request
}
@ -290,6 +295,7 @@ class SocketClient:
try:
message = {
"id": request_id,
"workspace": self.workspace,
"service": service,
"request": request
}
@ -328,6 +334,7 @@ class SocketClient:
try:
message = {
"id": request_id,
"workspace": self.workspace,
"service": service,
"request": request
}


@ -45,10 +45,13 @@ class ConfigValue:
type: Configuration type/category
key: Specific configuration key
value: Configuration value as string
workspace: Workspace the value belongs to. Only populated for
responses to getvalues-all-ws; empty otherwise.
"""
type : str
key : str
value : str
workspace : str = ""
@dataclasses.dataclass
class DocumentMetadata:
@ -62,7 +65,7 @@ class DocumentMetadata:
title: Document title
comments: Additional comments or description
metadata: List of RDF triples providing structured metadata
user: User/owner identifier
workspace: Workspace the document belongs to
tags: List of tags for categorization
parent_id: Parent document ID for child documents (empty for top-level docs)
document_type: "source" for uploaded documents, "extracted" for derived content
@ -73,7 +76,7 @@ class DocumentMetadata:
title : str
comments : str
metadata : List[Triple]
user : str
workspace : str
tags : List[str]
parent_id : str = ""
document_type : str = "source"
@ -88,7 +91,7 @@ class ProcessingMetadata:
document_id: ID of the document being processed
time: Processing start timestamp
flow: Flow instance handling the processing
user: User identifier
workspace: Workspace the processing job belongs to
collection: Target collection for processed data
tags: List of tags for categorization
"""
@ -96,7 +99,7 @@ class ProcessingMetadata:
document_id : str
time : datetime.datetime
flow : str
user : str
workspace : str
collection : str
tags : List[str]
@ -105,17 +108,15 @@ class CollectionMetadata:
"""
Metadata for a data collection.
Collections provide logical grouping and isolation for documents and
knowledge graph data.
Collections provide logical grouping within a workspace for documents
and knowledge graph data.
Attributes:
user: User/owner identifier
collection: Collection identifier
name: Human-readable collection name
description: Collection description
tags: List of tags for categorization
"""
user : str
collection : str
name : str
description : str


@ -125,21 +125,39 @@ class AsyncProcessor:
response_metrics = config_resp_metrics,
)
async def fetch_config(self):
"""Fetch full config from config service using a short-lived
request/response client. Returns (config, version) or raises."""
client = self._create_config_client()
try:
await client.start()
resp = await client.request(
ConfigRequest(operation="config"),
timeout=10,
)
if resp.error:
raise RuntimeError(f"Config error: {resp.error.message}")
return resp.config, resp.version
finally:
await client.stop()
async def _fetch_type_workspace(self, client, workspace, config_type):
"""Fetch config values of a single type within one workspace.
Returns dict of {key: value}."""
resp = await client.request(
ConfigRequest(
operation="getvalues",
workspace=workspace,
type=config_type,
),
timeout=10,
)
if resp.error:
raise RuntimeError(f"Config error: {resp.error.message}")
return {v.key: v.value for v in resp.values}
async def _fetch_type_all_workspaces(self, client, config_type):
"""Fetch config values of a single type across all workspaces.
Returns a tuple of ({workspace: {key: value}}, version)."""
resp = await client.request(
ConfigRequest(
operation="getvalues-all-ws",
type=config_type,
),
timeout=10,
)
if resp.error:
raise RuntimeError(f"Config error: {resp.error.message}")
grouped = {}
for v in resp.values:
ws = grouped.setdefault(v.workspace, {})
ws[v.key] = v.value
return grouped, resp.version
# This is called to start dynamic behaviour.
# Implements the subscribe-then-fetch pattern to avoid race conditions.
@ -155,21 +173,51 @@ class AsyncProcessor:
# processed by on_config_notify, which does the version check
async def fetch_and_apply_config(self):
"""Fetch full config from config service and apply to all handlers.
Retries until successful; the config service may not be ready yet."""
"""Startup: for each registered handler, fetch config for all its
types across all workspaces and invoke the handler once per
workspace. Retries until successful; the config service may not be
ready yet."""
while self.running:
try:
config, version = await self.fetch_config()
client = self._create_config_client()
try:
await client.start()
logger.info(f"Fetched config version {version}")
version = 0
self.config_version = version
for entry in self.config_handlers:
handler_types = entry["types"]
# Apply to all handlers (startup = invoke all)
for entry in self.config_handlers:
await entry["handler"](config, version)
# Handlers registered without types get nothing
# at startup (there is no "all types" fetch).
if not handler_types:
continue
# Group all registered types by workspace:
# {workspace: {type: {key: value}}}
per_ws = {}
for t in handler_types:
type_data, v = \
await self._fetch_type_all_workspaces(
client, t,
)
version = max(version, v)
for ws, kv in type_data.items():
per_ws.setdefault(ws, {})[t] = kv
# Call the handler once per workspace
for ws, config in per_ws.items():
await entry["handler"](ws, config, version)
logger.info(
f"Applied startup config version {version}"
)
self.config_version = version
finally:
await client.stop()
return
@ -204,8 +252,9 @@ class AsyncProcessor:
# Called when a config notify message arrives
async def on_config_notify(self, message, consumer, flow):
notify_version = message.value().version
notify_types = set(message.value().types)
v = message.value()
notify_version = v.version
changes = v.changes # dict of type -> [workspaces]
# Skip if we already have this version or newer
if notify_version <= self.config_version:
@ -215,41 +264,60 @@ class AsyncProcessor:
)
return
# Check if any handler cares about the affected types
if notify_types:
any_interested = False
for entry in self.config_handlers:
handler_types = entry["types"]
if handler_types is None or notify_types & handler_types:
any_interested = True
break
notify_types = set(changes.keys())
if not any_interested:
logger.debug(
f"Ignoring config notify v{notify_version}, "
f"no handlers for types {notify_types}"
)
self.config_version = notify_version
return
# Filter out handlers that don't care about any of the changed
# types. A handler registered without types never fires on
# notifications (nothing to scope to).
interested = []
for entry in self.config_handlers:
handler_types = entry["types"]
if handler_types and notify_types & handler_types:
interested.append(entry)
if not interested:
logger.debug(
f"Ignoring config notify v{notify_version}, "
f"no handlers for types {notify_types}"
)
self.config_version = notify_version
return
logger.info(
f"Config notify v{notify_version} types={list(notify_types)}, "
f"fetching config..."
f"Config notify v{notify_version} "
f"types={list(notify_types)}, fetching config..."
)
# Fetch full config using short-lived client
try:
config, version = await self.fetch_config()
client = self._create_config_client()
try:
await client.start()
self.config_version = version
for entry in interested:
handler_types = entry["types"]
# Invoke handlers that care about the affected types
for entry in self.config_handlers:
handler_types = entry["types"]
if handler_types is None:
await entry["handler"](config, version)
elif not notify_types or notify_types & handler_types:
await entry["handler"](config, version)
# Build {workspace: {type: {key: value}}} for types
# this handler cares about, where the workspace was
# affected for that type.
per_ws = {}
for t in handler_types:
if t not in changes:
continue
for ws in changes[t]:
kv = await self._fetch_type_workspace(
client, ws, t,
)
per_ws.setdefault(ws, {})[t] = kv
for ws, config in per_ws.items():
await entry["handler"](
ws, config, notify_version,
)
finally:
await client.stop()
self.config_version = notify_version
except Exception as e:
logger.error(
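
Review note: the notify path above builds a `{workspace: {type: {key: value}}}` map from the `changes` dict and calls each handler once per affected workspace. A minimal sketch of that grouping step follows. It assumes an in-memory stand-in for the config service; `build_per_workspace`, `fake_fetch` and `STORE` are hypothetical names for illustration and are not part of this commit.

```python
import asyncio


async def build_per_workspace(changes, handler_types, fetch):
    """Return {workspace: {type: {key: value}}} restricted to the
    types a handler registered for and the workspaces that changed."""
    per_ws = {}
    for t in sorted(handler_types):
        # Types the notification does not mention are skipped.
        for ws in changes.get(t, []):
            per_ws.setdefault(ws, {})[t] = await fetch(ws, t)
    return per_ws


# Hypothetical in-memory stand-in for the config service.
STORE = {
    ("acme", "collection"): {"docs": "{}"},
    ("acme", "prompt"): {"p1": "hello"},
    ("beta", "collection"): {"notes": "{}"},
}


async def fake_fetch(workspace, config_type):
    # Mimics a workspace-scoped "getvalues" call.
    return dict(STORE.get((workspace, config_type), {}))


# Shape of the notification payload: type -> [workspaces]
changes = {
    "collection": ["acme", "beta"],
    "prompt": ["acme"],
    "token-costs": ["acme"],  # not registered by this handler
}

result = asyncio.run(
    build_per_workspace(changes, {"collection", "prompt"}, fake_fetch)
)
print(result)
```

Only the (workspace, type) pairs named in the notification are fetched, so a change in one workspace does not cause handlers to reprocess the others.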

@ -48,12 +48,13 @@ class ChunkingService(FlowProcessor):
await super(ChunkingService, self).start()
await self.librarian.start()
async def get_document_text(self, doc):
async def get_document_text(self, doc, workspace):
"""
Get text content from a TextDocument, fetching from librarian if needed.
Args:
doc: TextDocument with either inline text or document_id
workspace: Workspace for librarian lookup (from flow.workspace)
Returns:
str: The document text content
@ -62,7 +63,7 @@ class ChunkingService(FlowProcessor):
logger.info(f"Fetching document {doc.document_id} from librarian...")
text = await self.librarian.fetch_document_text(
document_id=doc.document_id,
user=doc.metadata.user,
workspace=workspace,
)
logger.info(f"Fetched {len(text)} characters from librarian")
return text

@ -15,114 +15,139 @@ class CollectionConfigHandler:
Storage services should:
1. Inherit from this class along with their service base class
2. Call register_config_handler(self.on_collection_config) in __init__
3. Implement create_collection(user, collection, metadata) method
4. Implement delete_collection(user, collection) method
3. Implement create_collection(workspace, collection, metadata) method
4. Implement delete_collection(workspace, collection) method
"""
def __init__(self, **kwargs):
# Track known collections: {(user, collection): metadata_dict}
# Track known collections: {(workspace, collection): metadata_dict}
self.known_collections: Dict[tuple, dict] = {}
# Pass remaining kwargs up the inheritance chain
super().__init__(**kwargs)
async def on_collection_config(self, config: dict, version: int):
async def on_collection_config(
self, workspace: str, config: dict, version: int
):
"""
Handle config push messages and extract collection information
for a single workspace.
Args:
workspace: Workspace the config applies to
config: Configuration dictionary from ConfigPush message
version: Configuration version number
"""
logger.info(f"Processing collection configuration (version {version})")
logger.info(
f"Processing collection configuration "
f"(version {version}, workspace {workspace})"
)
# Extract collections from config (treat missing key as empty)
# Extract collections from config (treat missing key as empty).
# Each config key IS the collection name — config is already
# partitioned by workspace, so no workspace prefix is needed
# on the key.
collection_config = config.get("collection", {})
# Track which collections we've seen in this config
current_collections: Set[tuple] = set()
# Process each collection in the config
for key, value_json in collection_config.items():
for collection, value_json in collection_config.items():
try:
# Parse user:collection key
if ":" not in key:
logger.warning(f"Invalid collection key format (expected user:collection): {key}")
continue
current_collections.add((workspace, collection))
user, collection = key.split(":", 1)
current_collections.add((user, collection))
# Parse metadata
metadata = json.loads(value_json)
# Check if this is a new collection or updated
collection_key = (user, collection)
if collection_key not in self.known_collections:
logger.info(f"New collection detected: {user}/{collection}")
await self.create_collection(user, collection, metadata)
self.known_collections[collection_key] = metadata
key = (workspace, collection)
if key not in self.known_collections:
logger.info(
f"New collection detected: {workspace}/{collection}"
)
await self.create_collection(
workspace, collection, metadata
)
self.known_collections[key] = metadata
else:
# Collection already exists, update metadata if changed
if self.known_collections[collection_key] != metadata:
logger.info(f"Collection metadata updated: {user}/{collection}")
# Most storage services don't need to do anything for metadata updates
# They just need to know the collection exists
self.known_collections[collection_key] = metadata
if self.known_collections[key] != metadata:
logger.info(
f"Collection metadata updated: "
f"{workspace}/{collection}"
)
self.known_collections[key] = metadata
except Exception as e:
logger.error(f"Error processing collection config for key {key}: {e}", exc_info=True)
logger.error(
f"Error processing collection config for "
f"{workspace}/{collection}: {e}",
exc_info=True,
)
# Find collections that were deleted (in known but not in current)
deleted_collections = set(self.known_collections.keys()) - current_collections
for user, collection in deleted_collections:
logger.info(f"Collection deleted: {user}/{collection}")
# Find collections for THIS workspace that were deleted (in
# known but not in current). Only compare collections owned by
# this workspace — other workspaces' collections are not
# affected by this config update.
known_for_ws = {
(w, c) for (w, c) in self.known_collections.keys()
if w == workspace
}
deleted_collections = known_for_ws - current_collections
for ws, collection in deleted_collections:
logger.info(f"Collection deleted: {ws}/{collection}")
try:
# Remove from known_collections FIRST to immediately reject new writes
# This eliminates race condition with worker threads
del self.known_collections[(user, collection)]
# Physical deletion happens after - worker threads already rejecting writes
await self.delete_collection(user, collection)
# Remove from known_collections FIRST to immediately
# reject new writes
del self.known_collections[(ws, collection)]
await self.delete_collection(ws, collection)
except Exception as e:
logger.error(f"Error deleting collection {user}/{collection}: {e}", exc_info=True)
# If physical deletion failed, should we re-add to known_collections?
# For now, keep it removed - collection is logically deleted per config
logger.error(
f"Error deleting collection {ws}/{collection}: {e}",
exc_info=True,
)
logger.debug(f"Collection config processing complete. Known collections: {len(self.known_collections)}")
logger.debug(
f"Collection config processing complete. "
f"Known collections: {len(self.known_collections)}"
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(
self, workspace: str, collection: str, metadata: dict,
):
"""
Create a collection in the storage backend.
Subclasses must implement this method.
Args:
user: User ID
workspace: Workspace ID
collection: Collection ID
metadata: Collection metadata dictionary
"""
raise NotImplementedError("Storage service must implement create_collection method")
raise NotImplementedError(
"Storage service must implement create_collection method"
)
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""
Delete a collection from the storage backend.
Subclasses must implement this method.
Args:
user: User ID
workspace: Workspace ID
collection: Collection ID
"""
raise NotImplementedError("Storage service must implement delete_collection method")
raise NotImplementedError(
"Storage service must implement delete_collection method"
)
def collection_exists(self, user: str, collection: str) -> bool:
def collection_exists(self, workspace: str, collection: str) -> bool:
"""
Check if a collection is known to exist
Check if a collection is known to exist.
Args:
user: User ID
workspace: Workspace ID
collection: Collection ID
Returns:
True if collection exists, False otherwise
"""
return (user, collection) in self.known_collections
return (workspace, collection) in self.known_collections
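
Review note: the deletion step now compares only the collections owned by the workspace whose config arrived. The sketch below isolates that set arithmetic under the same `(workspace, collection)` key convention; `deleted_in_workspace` is a hypothetical helper name used for illustration only.

```python
def deleted_in_workspace(known, workspace, current_names):
    """Return the (workspace, collection) keys that are known for
    this workspace but absent from its latest config."""
    known_for_ws = {(w, c) for (w, c) in known if w == workspace}
    current = {(workspace, c) for c in current_names}
    return known_for_ws - current


# Known collections across two workspaces
known = {
    ("acme", "docs"): {},
    ("acme", "old"): {},
    ("beta", "old"): {},
}

# Config update for "acme" now lists only "docs"
deleted = deleted_in_workspace(known, "acme", ["docs"])
print(deleted)
```

Without the per-workspace filter, a config update for `acme` would also treat `("beta", "old")` as deleted, because it does not appear in the `acme` config.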

@ -18,10 +18,11 @@ class ConfigClient(RequestResponse):
)
return resp
async def get(self, type, key, timeout=CONFIG_TIMEOUT):
async def get(self, workspace, type, key, timeout=CONFIG_TIMEOUT):
"""Get a single config value. Returns the value string or None."""
resp = await self._request(
operation="get",
workspace=workspace,
keys=[ConfigKey(type=type, key=key)],
timeout=timeout,
)
@ -29,19 +30,21 @@ class ConfigClient(RequestResponse):
return resp.values[0].value
return None
async def put(self, type, key, value, timeout=CONFIG_TIMEOUT):
async def put(self, workspace, type, key, value, timeout=CONFIG_TIMEOUT):
"""Put a single config value."""
await self._request(
operation="put",
workspace=workspace,
values=[ConfigValue(type=type, key=key, value=value)],
timeout=timeout,
)
async def put_many(self, values, timeout=CONFIG_TIMEOUT):
"""Put multiple config values in a single request.
values is a list of (type, key, value) tuples."""
async def put_many(self, workspace, values, timeout=CONFIG_TIMEOUT):
"""Put multiple config values in a single request within a
single workspace. values is a list of (type, key, value) tuples."""
await self._request(
operation="put",
workspace=workspace,
values=[
ConfigValue(type=t, key=k, value=v)
for t, k, v in values
@ -49,19 +52,21 @@ class ConfigClient(RequestResponse):
timeout=timeout,
)
async def delete(self, type, key, timeout=CONFIG_TIMEOUT):
async def delete(self, workspace, type, key, timeout=CONFIG_TIMEOUT):
"""Delete a single config key."""
await self._request(
operation="delete",
workspace=workspace,
keys=[ConfigKey(type=type, key=key)],
timeout=timeout,
)
async def delete_many(self, keys, timeout=CONFIG_TIMEOUT):
"""Delete multiple config keys in a single request.
keys is a list of (type, key) tuples."""
async def delete_many(self, workspace, keys, timeout=CONFIG_TIMEOUT):
"""Delete multiple config keys in a single request within a
single workspace. keys is a list of (type, key) tuples."""
await self._request(
operation="delete",
workspace=workspace,
keys=[
ConfigKey(type=t, key=k)
for t, k in keys
@ -69,15 +74,26 @@ class ConfigClient(RequestResponse):
timeout=timeout,
)
async def keys(self, type, timeout=CONFIG_TIMEOUT):
"""List all keys for a config type."""
async def keys(self, workspace, type, timeout=CONFIG_TIMEOUT):
"""List all keys for a config type within a workspace."""
resp = await self._request(
operation="list",
workspace=workspace,
type=type,
timeout=timeout,
)
return resp.directory
async def workspaces_for_type(self, type, timeout=CONFIG_TIMEOUT):
"""Return the set of distinct workspaces with any config of
the given type."""
resp = await self._request(
operation="getvalues-all-ws",
type=type,
timeout=timeout,
)
return {v.workspace for v in resp.values if v.workspace}
class ConfigClientSpec(RequestResponseSpec):
def __init__(

@ -24,7 +24,10 @@ class ConsumerSpec(Spec):
flow = flow,
backend = processor.pubsub,
topic = definition["topics"][self.name],
subscriber = processor.id + "--" + flow.name + "--" + self.name,
subscriber = (
processor.id + "--" + flow.workspace + "--" +
flow.name + "--" + self.name
),
schema = self.schema,
handler = self.handler,
metrics = consumer_metrics,
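
Review note: adding `flow.workspace` to the subscriber name keeps two workspaces that run a flow with the same name from sharing one subscription. A small sketch of the naming scheme, assuming the same `--` separator; `subscriber_name` is a hypothetical helper, not a function in this commit.

```python
def subscriber_name(processor_id, workspace, flow_name, spec_name):
    """Compose a subscriber name that is unique per workspace and flow."""
    return "--".join([processor_id, workspace, flow_name, spec_name])


a = subscriber_name("chunker", "acme", "default", "input")
b = subscriber_name("chunker", "beta", "default", "input")
print(a)
print(b)
```

One caveat worth checking in review: the name is only unambiguous if workspace and flow identifiers cannot themselves contain `--`.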

@ -60,7 +60,9 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
logger.debug(f"Handling document embeddings query request {id}...")
docs = await self.query_document_embeddings(request)
docs = await self.query_document_embeddings(
flow.workspace, request,
)
logger.debug("Sending document embeddings query response...")
r = DocumentEmbeddingsResponse(chunks=docs, error=None)

@ -41,7 +41,8 @@ class DocumentEmbeddingsStoreService(FlowProcessor):
request = msg.value()
await self.store_document_embeddings(request)
# Workspace comes from the flow the message arrived on.
await self.store_document_embeddings(flow.workspace, request)
except TooManyRequests as e:
raise e

@ -4,15 +4,16 @@ import asyncio
class Flow:
"""
Runtime representation of a deployed flow process.
This class maintains internal processor states and orchestrates
lifecycles (start, stop) for inputs (consumers) and parameters
lifecycles (start, stop) for inputs (consumers) and parameters
that drive data flowing across linked nodes.
"""
def __init__(self, id, flow, processor, defn):
def __init__(self, id, flow, workspace, processor, defn):
self.id = id
self.name = flow
self.workspace = workspace
self.producer = {}

@ -35,6 +35,8 @@ class FlowProcessor(AsyncProcessor):
)
# Initialise flow information state
# Keyed by (workspace, flow) tuples; each workspace has its own
# set of flow variants for this processor.
self.flows = {}
# These can be overridden by a derived class:
@ -48,23 +50,28 @@ class FlowProcessor(AsyncProcessor):
def register_specification(self, spec: Any) -> None:
self.specifications.append(spec)
# Start processing for a new flow
async def start_flow(self, flow, defn):
self.flows[flow] = Flow(self.id, flow, self, defn)
await self.flows[flow].start()
logger.info(f"Started flow: {flow}")
# Stop processing for a new flow
async def stop_flow(self, flow):
if flow in self.flows:
await self.flows[flow].stop()
del self.flows[flow]
logger.info(f"Stopped flow: {flow}")
# Start processing for a new flow within a workspace
async def start_flow(self, workspace, flow, defn):
key = (workspace, flow)
self.flows[key] = Flow(self.id, flow, workspace, self, defn)
await self.flows[key].start()
logger.info(f"Started flow: {workspace}/{flow}")
# Event handler - called for a configuration change
async def on_configure_flows(self, config, version):
# Stop processing for a flow within a workspace
async def stop_flow(self, workspace, flow):
key = (workspace, flow)
if key in self.flows:
await self.flows[key].stop()
del self.flows[key]
logger.info(f"Stopped flow: {workspace}/{flow}")
logger.info(f"Got config version {version}")
# Event handler - called for a configuration change for a single
# workspace
async def on_configure_flows(self, workspace, config, version):
logger.info(
f"Got config version {version} for workspace {workspace}"
)
config_type = f"processor:{self.id}"
@ -76,26 +83,28 @@ class FlowProcessor(AsyncProcessor):
for k, v in config[config_type].items()
}
else:
logger.debug("No configuration settings for me.")
logger.debug(
f"No configuration settings for me in {workspace}."
)
flow_config = {}
# Get list of flows which should be running and are currently
# running
wanted_flows = flow_config.keys()
# This takes a copy, needed because dict gets modified by stop_flow
current_flows = list(self.flows.keys())
# Get list of flows which should be running in this workspace,
# and the list currently running in this workspace
wanted_flows = set(flow_config.keys())
current_flows = {
f for (ws, f) in self.flows.keys() if ws == workspace
}
# Start all the flows which aren't currently running
for flow in wanted_flows:
if flow not in current_flows:
await self.start_flow(flow, flow_config[flow])
# Start all the flows which aren't currently running in this
# workspace
for flow in wanted_flows - current_flows:
await self.start_flow(workspace, flow, flow_config[flow])
# Stop all the unwanted flows which are due to be stopped
for flow in current_flows:
if flow not in wanted_flows:
await self.stop_flow(flow)
# Stop all the unwanted flows in this workspace
for flow in current_flows - wanted_flows:
await self.stop_flow(workspace, flow)
logger.info("Handled config update")
logger.info(f"Handled config update for workspace {workspace}")
# Start threads, just call parent
async def start(self):
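
Review note: `on_configure_flows` now reconciles flows for one workspace at a time, using set differences over `(workspace, flow)` keys. The sketch below shows the same planning step in isolation; `plan_flow_changes` is a hypothetical name, and the results are sorted only to make the output deterministic.

```python
def plan_flow_changes(running, workspace, wanted):
    """Return (to_start, to_stop) flow names for a single workspace.

    running: iterable of (workspace, flow) keys currently active
    wanted:  iterable of flow names the workspace config asks for
    """
    current = {f for (ws, f) in running if ws == workspace}
    wanted = set(wanted)
    return sorted(wanted - current), sorted(current - wanted)


running = {
    ("acme", "default"),
    ("acme", "legacy"),
    ("beta", "default"),
}

to_start, to_stop = plan_flow_changes(
    running, "acme", ["default", "research"]
)
print(to_start, to_stop)
```

Flows in `beta` are untouched by an `acme` update, which is what the per-workspace filter on `self.flows.keys()` is there to guarantee.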

@ -60,7 +60,9 @@ class GraphEmbeddingsQueryService(FlowProcessor):
logger.debug(f"Handling graph embeddings query request {id}...")
entities = await self.query_graph_embeddings(request)
entities = await self.query_graph_embeddings(
flow.workspace, request,
)
logger.debug("Sending graph embeddings query response...")
r = GraphEmbeddingsResponse(entities=entities, error=None)

@ -41,7 +41,8 @@ class GraphEmbeddingsStoreService(FlowProcessor):
request = msg.value()
await self.store_graph_embeddings(request)
# Workspace comes from the flow the message arrived on.
await self.store_graph_embeddings(flow.workspace, request)
except TooManyRequests as e:
raise e

@ -150,7 +150,7 @@ class LibrarianClient:
finally:
self._streams.pop(request_id, None)
async def fetch_document_content(self, document_id, user, timeout=120):
async def fetch_document_content(self, document_id, workspace, timeout=120):
"""Fetch document content using streaming.
Returns base64-encoded content. Caller is responsible for decoding.
@ -158,7 +158,7 @@ class LibrarianClient:
req = LibrarianRequest(
operation="stream-document",
document_id=document_id,
user=user,
workspace=workspace,
)
chunks = await self.stream(req, timeout=timeout)
@ -176,24 +176,24 @@ class LibrarianClient:
return base64.b64encode(raw)
async def fetch_document_text(self, document_id, user, timeout=120):
async def fetch_document_text(self, document_id, workspace, timeout=120):
"""Fetch document content and decode as UTF-8 text."""
content = await self.fetch_document_content(
document_id, user, timeout=timeout,
document_id, workspace, timeout=timeout,
)
return base64.b64decode(content).decode("utf-8")
async def fetch_document_metadata(self, document_id, user, timeout=120):
async def fetch_document_metadata(self, document_id, workspace, timeout=120):
"""Fetch document metadata from the librarian."""
req = LibrarianRequest(
operation="get-document-metadata",
document_id=document_id,
user=user,
workspace=workspace,
)
response = await self.request(req, timeout=timeout)
return response.document_metadata
async def save_child_document(self, doc_id, parent_id, user, content,
async def save_child_document(self, doc_id, parent_id, workspace, content,
document_type="chunk", title=None,
kind="text/plain", timeout=120):
"""Save a child document to the librarian."""
@ -202,7 +202,7 @@ class LibrarianClient:
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind=kind,
title=title or doc_id,
parent_id=parent_id,
@ -218,7 +218,7 @@ class LibrarianClient:
await self.request(req, timeout=timeout)
return doc_id
async def save_document(self, doc_id, user, content, title=None,
async def save_document(self, doc_id, workspace, content, title=None,
document_type="answer", kind="text/plain",
timeout=120):
"""Save a document to the librarian."""
@ -227,7 +227,7 @@ class LibrarianClient:
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind=kind,
title=title or doc_id,
document_type=document_type,
@ -238,7 +238,7 @@ class LibrarianClient:
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
user=user,
workspace=workspace,
)
await self.request(req, timeout=timeout)

@ -133,8 +133,9 @@ class RequestResponseSpec(Spec):
# Make subscription names unique, so that all subscribers get
# to see all response messages
subscription = (
processor.id + "--" + flow.name + "--" + self.request_name +
"--" + str(uuid.uuid4())
processor.id + "--" + flow.workspace + "--" +
flow.name + "--" + self.request_name + "--" +
str(uuid.uuid4())
),
consumer_name = flow.id,
request_topic = definition["topics"][self.request_name],

@ -21,7 +21,7 @@ class SubscriberSpec(Spec):
subscriber = Subscriber(
backend = processor.pubsub,
topic = definition["topics"][self.name],
subscription = flow.id,
subscription = flow.id + "--" + flow.workspace + "--" + flow.name,
consumer_name = flow.id,
schema = self.schema,
metrics = subscriber_metrics,

@ -64,6 +64,7 @@ class ToolService(FlowProcessor):
id = msg.properties()["id"]
response = await self.invoke_tool(
flow.workspace,
request.name,
json.loads(request.parameters) if request.parameters else {},
)

@ -58,9 +58,13 @@ class TriplesQueryService(FlowProcessor):
logger.debug(f"Handling triples query request {id}...")
workspace = flow.workspace
if request.streaming:
# Streaming mode: send batches
async for batch, is_final in self.query_triples_stream(request):
async for batch, is_final in self.query_triples_stream(
workspace, request,
):
r = TriplesQueryResponse(
triples=batch,
error=None,
@ -70,7 +74,7 @@ class TriplesQueryService(FlowProcessor):
logger.debug("Triples query streaming completed")
else:
# Non-streaming mode: single response
triples = await self.query_triples(request)
triples = await self.query_triples(workspace, request)
logger.debug("Sending triples query response...")
r = TriplesQueryResponse(triples=triples, error=None)
await flow("response").send(r, properties={"id": id})
@ -92,13 +96,13 @@ class TriplesQueryService(FlowProcessor):
await flow("response").send(r, properties={"id": id})
async def query_triples_stream(self, request):
async def query_triples_stream(self, workspace, request):
"""
Streaming query - yields (batch, is_final) tuples.
Default implementation batches results from query_triples.
Override for true streaming from backend.
"""
triples = await self.query_triples(request)
triples = await self.query_triples(workspace, request)
batch_size = request.batch_size if request.batch_size > 0 else 20
for i in range(0, len(triples), batch_size):
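
Review note: the hunk cuts off before the loop body, so the sketch below is only one plausible way to yield `(batch, is_final)` tuples from a fully materialised result list. It is synchronous for brevity, whereas the method in the diff is an async generator. The handling of an empty result (a single empty final batch) is an assumption and is not taken from the source.

```python
def batches(items, batch_size=20):
    """Yield (batch, is_final) tuples over a list of results."""
    if batch_size <= 0:
        batch_size = 20  # same fallback as the default in the diff
    if not items:
        # Assumption: signal completion even when there is no data.
        yield [], True
        return
    for i in range(0, len(items), batch_size):
        batch = items[i:i + batch_size]
        # The last slice is the one that reaches the end of the list.
        yield batch, i + batch_size >= len(items)


out = list(batches(list(range(5)), batch_size=2))
print(out)
```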

@ -45,7 +45,10 @@ class TriplesStoreService(FlowProcessor):
request = msg.value()
await self.store_triples(request)
# Workspace is derived from the flow the message arrived on,
# not from fields in the message payload. Topic routing is
# the isolation boundary.
await self.store_triples(flow.workspace, request)
except TooManyRequests as e:
raise e

@ -33,6 +33,7 @@ class ConfigClient(BaseClient):
subscriber=None,
input_queue=None,
output_queue=None,
workspace="default",
**pubsub_config,
):
@ -51,10 +52,13 @@ class ConfigClient(BaseClient):
**pubsub_config,
)
self.workspace = workspace
def get(self, keys, timeout=300):
resp = self.call(
operation="get",
workspace=self.workspace,
keys=[
ConfigKey(
type = k["type"],
@ -78,6 +82,7 @@ class ConfigClient(BaseClient):
resp = self.call(
operation="list",
workspace=self.workspace,
type=type,
timeout=timeout
)
@ -88,6 +93,7 @@ class ConfigClient(BaseClient):
resp = self.call(
operation="getvalues",
workspace=self.workspace,
type=type,
timeout=timeout
)
@ -101,10 +107,31 @@ class ConfigClient(BaseClient):
for v in resp.values
]
def getvalues_all_ws(self, type, timeout=300):
"""Fetch all values of a given type across all workspaces.
Returns a list of dicts including a 'workspace' field."""
resp = self.call(
operation="getvalues-all-ws",
type=type,
timeout=timeout
)
return [
{
"workspace": v.workspace,
"type": v.type,
"key": v.key,
"value": v.value,
}
for v in resp.values
]
def delete(self, keys, timeout=300):
resp = self.call(
operation="delete",
workspace=self.workspace,
keys=[
ConfigKey(
type = k["type"],
@ -121,6 +148,7 @@ class ConfigClient(BaseClient):
resp = self.call(
operation="put",
workspace=self.workspace,
values=[
ConfigValue(
type = v["type"],
@ -138,6 +166,7 @@ class ConfigClient(BaseClient):
resp = self.call(
operation="config",
workspace=self.workspace,
timeout=timeout
)

@ -9,7 +9,7 @@ class CollectionManagementRequestTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> CollectionManagementRequest:
return CollectionManagementRequest(
operation=data.get("operation"),
user=data.get("user"),
workspace=data.get("workspace", ""),
collection=data.get("collection"),
timestamp=data.get("timestamp"),
name=data.get("name"),
@ -24,8 +24,8 @@ class CollectionManagementRequestTranslator(MessageTranslator):
if obj.operation is not None:
result["operation"] = obj.operation
if obj.user is not None:
result["user"] = obj.user
if obj.workspace:
result["workspace"] = obj.workspace
if obj.collection is not None:
result["collection"] = obj.collection
if obj.timestamp is not None:
@ -63,7 +63,6 @@ class CollectionManagementResponseTranslator(MessageTranslator):
if "collections" in data:
for coll_data in data["collections"]:
collections.append(CollectionMetadata(
user=coll_data.get("user"),
collection=coll_data.get("collection"),
name=coll_data.get("name"),
description=coll_data.get("description"),
@ -91,7 +90,6 @@ class CollectionManagementResponseTranslator(MessageTranslator):
result["collections"] = []
for coll in obj.collections:
result["collections"].append({
"user": coll.user,
"collection": coll.collection,
"name": coll.name,
"description": coll.description,

@ -23,13 +23,15 @@ class ConfigRequestTranslator(MessageTranslator):
ConfigValue(
type=v["type"],
key=v["key"],
value=v["value"]
value=v["value"],
workspace=v.get("workspace", ""),
)
for v in data["values"]
]
return ConfigRequest(
operation=data.get("operation"),
workspace=data.get("workspace", ""),
keys=keys,
type=data.get("type"),
values=values
@ -37,10 +39,13 @@ class ConfigRequestTranslator(MessageTranslator):
def encode(self, obj: ConfigRequest) -> Dict[str, Any]:
result = {}
if obj.operation is not None:
result["operation"] = obj.operation
if obj.workspace is not None:
result["workspace"] = obj.workspace
if obj.type is not None:
result["type"] = obj.type
@ -56,13 +61,14 @@ class ConfigRequestTranslator(MessageTranslator):
if obj.values is not None:
result["values"] = [
{
**({"workspace": v.workspace} if v.workspace else {}),
"type": v.type,
"key": v.key,
"value": v.value
"value": v.value,
}
for v in obj.values
]
return result
@ -81,13 +87,14 @@ class ConfigResponseTranslator(MessageTranslator):
if obj.values is not None:
result["values"] = [
{
**({"workspace": v.workspace} if v.workspace else {}),
"type": v.type,
"key": v.key,
"value": v.value
"value": v.value,
}
for v in obj.values
]
if obj.directory is not None:
result["directory"] = obj.directory
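
Review note: the translators add the `workspace` key only when it is non-empty, using a conditional dict unpack. A minimal sketch of that idiom, with a plain function standing in for the translator; `encode_value` is a hypothetical name.

```python
def encode_value(workspace, type, key, value):
    """Encode one config value, omitting 'workspace' when it is empty."""
    return {
        # Unpacks to nothing when workspace is "" or None.
        **({"workspace": workspace} if workspace else {}),
        "type": type,
        "key": key,
        "value": value,
    }


scoped = encode_value("acme", "prompt", "p1", "hello")
unscoped = encode_value("", "prompt", "p1", "hello")
print(scoped)
print(unscoped)
```

This keeps single-workspace responses identical to the previous wire format, while responses to `getvalues-all-ws` carry the extra field.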

@ -39,7 +39,6 @@ class DocumentTranslator(SendTranslator):
metadata=Metadata(
id=data.get("id"),
root=data.get("root", ""),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
),
data=base64.b64encode(doc).decode("utf-8")
@ -56,8 +55,6 @@ class DocumentTranslator(SendTranslator):
metadata_dict["id"] = obj.metadata.id
if obj.metadata.root:
metadata_dict["root"] = obj.metadata.root
if obj.metadata.user:
metadata_dict["user"] = obj.metadata.user
if obj.metadata.collection:
metadata_dict["collection"] = obj.metadata.collection
@ -79,7 +76,6 @@ class TextDocumentTranslator(SendTranslator):
metadata=Metadata(
id=data.get("id"),
root=data.get("root", ""),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
),
text=text.encode("utf-8")
@ -96,8 +92,6 @@ class TextDocumentTranslator(SendTranslator):
metadata_dict["id"] = obj.metadata.id
if obj.metadata.root:
metadata_dict["root"] = obj.metadata.root
if obj.metadata.user:
metadata_dict["user"] = obj.metadata.user
if obj.metadata.collection:
metadata_dict["collection"] = obj.metadata.collection
@ -115,7 +109,6 @@ class ChunkTranslator(SendTranslator):
metadata=Metadata(
id=data.get("id"),
root=data.get("root", ""),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
),
chunk=data["chunk"].encode("utf-8") if isinstance(data["chunk"], str) else data["chunk"]
@ -132,8 +125,6 @@ class ChunkTranslator(SendTranslator):
metadata_dict["id"] = obj.metadata.id
if obj.metadata.root:
metadata_dict["root"] = obj.metadata.root
if obj.metadata.user:
metadata_dict["user"] = obj.metadata.user
if obj.metadata.collection:
metadata_dict["collection"] = obj.metadata.collection
@ -161,7 +152,6 @@ class DocumentEmbeddingsTranslator(SendTranslator):
metadata=Metadata(
id=metadata.get("id"),
root=metadata.get("root", ""),
user=metadata.get("user", "trustgraph"),
collection=metadata.get("collection", "default"),
),
chunks=chunks
@ -184,8 +174,6 @@ class DocumentEmbeddingsTranslator(SendTranslator):
metadata_dict["id"] = obj.metadata.id
if obj.metadata.root:
metadata_dict["root"] = obj.metadata.root
if obj.metadata.user:
metadata_dict["user"] = obj.metadata.user
if obj.metadata.collection:
metadata_dict["collection"] = obj.metadata.collection

View file

@ -9,18 +9,21 @@ class FlowRequestTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> FlowRequest:
return FlowRequest(
operation=data.get("operation"),
workspace=data.get("workspace", ""),
blueprint_name=data.get("blueprint-name"),
blueprint_definition=data.get("blueprint-definition"),
description=data.get("description"),
flow_id=data.get("flow-id"),
parameters=data.get("parameters")
)
def encode(self, obj: FlowRequest) -> Dict[str, Any]:
result = {}
if obj.operation is not None:
result["operation"] = obj.operation
if obj.workspace is not None:
result["workspace"] = obj.workspace
if obj.blueprint_name is not None:
result["blueprint-name"] = obj.blueprint_name
if obj.blueprint_definition is not None:

View file

@ -21,7 +21,6 @@ class KnowledgeRequestTranslator(MessageTranslator):
metadata=Metadata(
id=data["triples"]["metadata"]["id"],
root=data["triples"]["metadata"].get("root", ""),
user=data["triples"]["metadata"]["user"],
collection=data["triples"]["metadata"]["collection"]
),
triples=self.subgraph_translator.decode(data["triples"]["triples"]),
@ -33,7 +32,6 @@ class KnowledgeRequestTranslator(MessageTranslator):
metadata=Metadata(
id=data["graph-embeddings"]["metadata"]["id"],
root=data["graph-embeddings"]["metadata"].get("root", ""),
user=data["graph-embeddings"]["metadata"]["user"],
collection=data["graph-embeddings"]["metadata"]["collection"]
),
entities=[
@ -47,7 +45,7 @@ class KnowledgeRequestTranslator(MessageTranslator):
return KnowledgeRequest(
operation=data.get("operation"),
user=data.get("user"),
workspace=data.get("workspace", ""),
id=data.get("id"),
flow=data.get("flow"),
collection=data.get("collection"),
@ -60,8 +58,8 @@ class KnowledgeRequestTranslator(MessageTranslator):
if obj.operation:
result["operation"] = obj.operation
if obj.user:
result["user"] = obj.user
if obj.workspace:
result["workspace"] = obj.workspace
if obj.id:
result["id"] = obj.id
if obj.flow:
@ -74,7 +72,6 @@ class KnowledgeRequestTranslator(MessageTranslator):
"metadata": {
"id": obj.triples.metadata.id,
"root": obj.triples.metadata.root,
"user": obj.triples.metadata.user,
"collection": obj.triples.metadata.collection,
},
"triples": self.subgraph_translator.encode(obj.triples.triples),
@ -85,7 +82,6 @@ class KnowledgeRequestTranslator(MessageTranslator):
"metadata": {
"id": obj.graph_embeddings.metadata.id,
"root": obj.graph_embeddings.metadata.root,
"user": obj.graph_embeddings.metadata.user,
"collection": obj.graph_embeddings.metadata.collection,
},
"entities": [
@ -122,7 +118,6 @@ class KnowledgeResponseTranslator(MessageTranslator):
"metadata": {
"id": obj.triples.metadata.id,
"root": obj.triples.metadata.root,
"user": obj.triples.metadata.user,
"collection": obj.triples.metadata.collection,
},
"triples": self.subgraph_translator.encode(obj.triples.triples),
@ -136,7 +131,6 @@ class KnowledgeResponseTranslator(MessageTranslator):
"metadata": {
"id": obj.graph_embeddings.metadata.id,
"root": obj.graph_embeddings.metadata.root,
"user": obj.graph_embeddings.metadata.user,
"collection": obj.graph_embeddings.metadata.collection,
},
"entities": [

View file

@ -49,7 +49,7 @@ class LibraryRequestTranslator(MessageTranslator):
document_metadata=doc_metadata,
processing_metadata=proc_metadata,
content=content,
user=data.get("user", ""),
workspace=data.get("workspace", ""),
collection=data.get("collection", ""),
criteria=criteria,
# Chunked upload fields
@ -76,8 +76,8 @@ class LibraryRequestTranslator(MessageTranslator):
result["processing-metadata"] = self.proc_metadata_translator.encode(obj.processing_metadata)
if obj.content:
result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content
if obj.user:
result["user"] = obj.user
if obj.workspace:
result["workspace"] = obj.workspace
if obj.collection:
result["collection"] = obj.collection
if obj.criteria is not None:

View file

@ -19,7 +19,7 @@ class DocumentMetadataTranslator(Translator):
title=data.get("title"),
comments=data.get("comments"),
metadata=self.subgraph_translator.decode(metadata) if metadata is not None else [],
user=data.get("user"),
workspace=data.get("workspace"),
tags=data.get("tags"),
parent_id=data.get("parent-id", ""),
document_type=data.get("document-type", "source"),
@ -40,8 +40,8 @@ class DocumentMetadataTranslator(Translator):
result["comments"] = obj.comments
if obj.metadata is not None:
result["metadata"] = self.subgraph_translator.encode(obj.metadata)
if obj.user:
result["user"] = obj.user
if obj.workspace:
result["workspace"] = obj.workspace
if obj.tags is not None:
result["tags"] = obj.tags
if obj.parent_id:
@ -61,7 +61,7 @@ class ProcessingMetadataTranslator(Translator):
document_id=data.get("document-id"),
time=data.get("time"),
flow=data.get("flow"),
user=data.get("user"),
workspace=data.get("workspace"),
collection=data.get("collection"),
tags=data.get("tags")
)
@ -77,8 +77,8 @@ class ProcessingMetadataTranslator(Translator):
result["time"] = obj.time
if obj.flow:
result["flow"] = obj.flow
if obj.user:
result["user"] = obj.user
if obj.workspace:
result["workspace"] = obj.workspace
if obj.collection:
result["collection"] = obj.collection
if obj.tags is not None:

View file

@ -8,6 +8,7 @@ class Metadata:
# Root document identifier (set by librarian, preserved through pipeline)
root: str = ""
# Collection management
user: str = ""
# Collection the message belongs to. Workspace is NOT carried on the
# message — consumers derive it from flow.workspace (the flow the
# message arrived on), which is the trusted isolation boundary.
collection: str = ""

View file

@ -17,7 +17,7 @@ from .embeddings import GraphEmbeddings
# <- (error)
# list-kg-cores
# -> (user)
# -> (workspace)
# <- ()
# <- (error)
@ -27,8 +27,8 @@ class KnowledgeRequest:
# load-kg-core, unload-kg-core
operation: str = ""
# list-kg-cores, delete-kg-core, put-kg-core
user: str = ""
# Workspace the cores belong to. Partition / isolation boundary.
workspace: str = ""
# get-kg-core, list-kg-cores, delete-kg-core, put-kg-core,
# load-kg-core, unload-kg-core

View file

@ -13,7 +13,6 @@ from ..core.topic import queue
@dataclass
class CollectionMetadata:
"""Collection metadata record"""
user: str = ""
collection: str = ""
name: str = ""
description: str = ""
@ -23,11 +22,17 @@ class CollectionMetadata:
@dataclass
class CollectionManagementRequest:
"""Request for collection management operations"""
"""Request for collection management operations.
Collection-management is a global (non-flow-scoped) service, so the
workspace has to travel on the wire: it's the isolation boundary
for which workspace's collections the request operates on.
"""
operation: str = "" # e.g., "delete-collection"
# For 'list-collections'
user: str = ""
# Workspace the collection belongs to.
workspace: str = ""
collection: str = ""
timestamp: str = "" # ISO timestamp
name: str = ""

View file

@ -7,12 +7,19 @@ from ..core.primitives import Error
############################################################################
# Config service:
# get(keys) -> (version, values)
# list(type) -> (version, values)
# getvalues(type) -> (version, values)
# put(values) -> ()
# delete(keys) -> ()
# config() -> (version, config)
# get(workspace, keys) -> (version, values)
# list(workspace, type) -> (version, directory)
# getvalues(workspace, type) -> (version, values)
# getvalues-all-ws(type) -> (version, values with workspace field)
# put(workspace, values) -> ()
# delete(workspace, keys) -> ()
# config(workspace) -> (version, config)
#
# Most operations are scoped to a workspace. The workspace field on the
# request identifies which workspace's config to read or modify.
# getvalues-all-ws returns values across all workspaces for a single
# type — used by shared processors to load type-scoped config at startup.
@dataclass
class ConfigKey:
type: str = ""
@ -23,16 +30,24 @@ class ConfigValue:
type: str = ""
key: str = ""
value: str = ""
# Populated by getvalues-all-ws responses so callers can identify
# which workspace each value belongs to. Empty otherwise.
workspace: str = ""
# Prompt services, abstract the prompt generation
@dataclass
class ConfigRequest:
operation: str = "" # get, list, getvalues, delete, put, config
# Operations: get, list, getvalues, getvalues-all-ws, delete, put,
# config
operation: str = ""
# Workspace scope — required on all operations except
# getvalues-all-ws which spans all workspaces.
workspace: str = ""
# get, delete
keys: list[ConfigKey] = field(default_factory=list)
# list, getvalues
# list, getvalues, getvalues-all-ws
type: str = ""
# put
@ -58,7 +73,12 @@ class ConfigResponse:
@dataclass
class ConfigPush:
version: int = 0
types: list[str] = field(default_factory=list)
# Dict of config type -> list of affected workspaces.
# Handlers look up their registered type and get the list of
# workspaces that need refreshing.
# e.g. {"prompt": ["workspace-a", "workspace-b"], "schema": ["workspace-a"]}
changes: dict[str, list[str]] = field(default_factory=dict)
config_request_queue = queue('config', cls='request')
config_response_queue = queue('config', cls='response')
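
The `changes` map on `ConfigPush` lets a handler find which workspaces need reloading for the one config type it registered for. A minimal sketch of that lookup follows. It assumes a local stand-in for the `ConfigPush` dataclass and a hypothetical helper, `workspaces_to_refresh`, which is not part of this commit.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ConfigPush:
    # Local stand-in mirroring the two fields of the schema class.
    version: int = 0
    changes: Dict[str, List[str]] = field(default_factory=dict)


def workspaces_to_refresh(push: ConfigPush, config_type: str) -> List[str]:
    """Hypothetical helper: workspaces whose `config_type` items changed.

    A type that is absent from the push means nothing needs reloading.
    """
    return push.changes.get(config_type, [])


push = ConfigPush(
    version=7,
    changes={
        "prompt": ["workspace-a", "workspace-b"],
        "schema": ["workspace-a"],
    },
)

print(workspaces_to_refresh(push, "prompt"))  # ['workspace-a', 'workspace-b']
print(workspaces_to_refresh(push, "tool"))    # []
```

Keying the push by type rather than by workspace keeps the common case cheap: a handler does one dictionary lookup and ignores pushes that do not mention its type.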

View file

@ -17,12 +17,14 @@ from ..core.primitives import Error
# start_flow(flowid, blueprintname) -> ()
# stop_flow(flowid) -> ()
# Prompt services, abstract the prompt generation
@dataclass
class FlowRequest:
operation: str = "" # list-blueprints, get-blueprint, put-blueprint, delete-blueprint
# list-flows, get-flow, start-flow, stop-flow
# Workspace scope — all operations act within this workspace
workspace: str = ""
# get_blueprint, put_blueprint, delete_blueprint, start_flow
blueprint_name: str = ""

View file

@ -43,12 +43,12 @@ from ..core.metadata import Metadata
# <- (error)
# list-documents
# -> (user, collection?)
# -> (workspace, collection?)
# <- (document_metadata[])
# <- (error)
# list-processing
# -> (user, collection?)
# -> (workspace, collection?)
# <- (processing_metadata[])
# <- (error)
@ -78,7 +78,7 @@ from ..core.metadata import Metadata
# <- (error)
# list-uploads
# -> (user)
# -> (workspace)
# <- (uploads[])
# <- (error)
@ -90,7 +90,7 @@ class DocumentMetadata:
title: str = ""
comments: str = ""
metadata: list[Triple] = field(default_factory=list)
user: str = ""
workspace: str = ""
tags: list[str] = field(default_factory=list)
# Child document support
parent_id: str = "" # Empty for top-level docs, set for children
@ -107,7 +107,7 @@ class ProcessingMetadata:
document_id: str = ""
time: int = 0
flow: str = ""
user: str = ""
workspace: str = ""
collection: str = ""
tags: list[str] = field(default_factory=list)
@ -162,8 +162,8 @@ class LibrarianRequest:
# add-document, upload-chunk
content: bytes = b""
# list-documents, list-processing, list-uploads
user: str = ""
# Workspace scopes every library operation.
workspace: str = ""
# list-documents?, list-processing?
collection: str = ""

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.3,<2.4",
"trustgraph-base>=2.4,<2.5",
"pulsar-client",
"prometheus-client",
"boto3",

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.3,<2.4",
"trustgraph-base>=2.4,<2.5",
"requests",
"pulsar-client",
"aiohttp",
@ -95,6 +95,8 @@ tg-list-config-items = "trustgraph.cli.list_config_items:main"
tg-get-config-item = "trustgraph.cli.get_config_item:main"
tg-put-config-item = "trustgraph.cli.put_config_item:main"
tg-delete-config-item = "trustgraph.cli.delete_config_item:main"
tg-export-workspace-config = "trustgraph.cli.export_workspace_config:main"
tg-import-workspace-config = "trustgraph.cli.import_workspace_config:main"
tg-list-collections = "trustgraph.cli.list_collections:main"
tg-set-collection = "trustgraph.cli.set_collection:main"
tg-delete-collection = "trustgraph.cli.delete_collection:main"

View file

@ -15,17 +15,17 @@ from trustgraph.knowledge import Organization, PublicationEvent
from trustgraph.knowledge import DigitalDocument
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = 'trustgraph'
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
class Loader:
def __init__(
self, id, url, user, metadata, title, comments, kind, tags
):
self, id, url, metadata, title, comments, kind, tags,
token=None, workspace="default"):
self.api = Api(url).library()
self.api = Api(url, token=token, workspace=workspace).library()
self.user = user
self.metadata = metadata
self.title = title
self.comments = comments
@ -55,13 +55,13 @@ class Loader:
else:
id = hash(data)
id = to_uri(PREF_DOC, id)
self.metadata.id = id
self.api.add_document(
document=data, id=id, metadata=self.metadata,
user=self.user, kind=self.kind, title=self.title,
document=data, id=id, metadata=self.metadata,
kind=self.kind, title=self.title,
comments=self.comments, tags=self.tags
)
@ -83,11 +83,16 @@ def main():
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
@ -186,12 +191,13 @@ def main():
p = Loader(
id=args.identifier,
url=args.url,
user=args.user,
metadata=document,
title=args.name,
comments=args.description,
kind=args.kind,
tags=args.tags,
token=args.token,
workspace=args.workspace,
)
p.load(args.files)

View file

@ -7,9 +7,11 @@ import os
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = "trustgraph"
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def delete_collection(url, user, collection, confirm):
def delete_collection(url, collection, confirm, token=None, workspace="default"):
if not confirm:
response = input(f"Are you sure you want to delete collection '{collection}' and all its data? (y/N): ")
@ -17,9 +19,9 @@ def delete_collection(url, user, collection, confirm):
print("Operation cancelled.")
return
api = Api(url).collection()
api = Api(url, token=token, workspace=workspace).collection()
api.delete_collection(user=user, collection=collection)
api.delete_collection(collection=collection)
print(f"Collection '{collection}' deleted successfully.")
@ -41,27 +43,34 @@ def main():
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-y', '--yes',
action='store_true',
help='Skip confirmation prompt'
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
args = parser.parse_args()
try:
delete_collection(
url = args.api_url,
user = args.user,
collection = args.collection,
confirm = args.yes
confirm = args.yes,
token = args.token,
workspace = args.workspace,
)
except Exception as e:
@ -69,4 +78,4 @@ def main():
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()
main()

View file

@ -9,10 +9,11 @@ from trustgraph.api.types import ConfigKey
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def delete_config_item(url, config_type, key, token=None):
def delete_config_item(url, config_type, key, token=None, workspace="default"):
api = Api(url, token=token).config()
api = Api(url, token=token, workspace=workspace).config()
config_key = ConfigKey(type=config_type, key=key)
api.delete([config_key])
@ -50,6 +51,12 @@ def main():
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
args = parser.parse_args()
try:
@ -59,6 +66,8 @@ def main():
config_type=args.type,
key=args.key,
token=args.token,
workspace=args.workspace,
)
except Exception as e:

View file

@ -9,10 +9,13 @@ from trustgraph.api import Api
import json
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def delete_flow_blueprint(url, blueprint_name):
def delete_flow_blueprint(url, blueprint_name, token=None,
workspace="default"):
api = Api(url).flow()
api = Api(url, token=token, workspace=workspace).flow()
blueprint_names = api.delete_blueprint(blueprint_name)
@ -29,6 +32,18 @@ def main():
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-n', '--blueprint-name',
help='Flow blueprint name',
@ -41,6 +56,8 @@ def main():
delete_flow_blueprint(
url=args.api_url,
blueprint_name=args.blueprint_name,
token=args.token,
workspace=args.workspace,
)
except Exception as e:

View file

@ -1,20 +1,20 @@
"""
Deletes a flow class
Deletes a knowledge core
"""
import argparse
import os
import tabulate
from trustgraph.api import Api
import json
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def delete_kg_core(url, user, id):
def delete_kg_core(url, id, token=None, workspace="default"):
api = Api(url).knowledge()
api = Api(url, token=token, workspace=workspace).knowledge()
class_names = api.delete_kg_core(user = user, id = id)
api.delete_kg_core(id=id)
def main():
@ -29,26 +29,33 @@ def main():
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-U', '--user',
default="trustgraph",
help='API URL (default: trustgraph)',
)
parser.add_argument(
'--id', '--identifier',
required=True,
help='Knowledge core ID',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
args = parser.parse_args()
try:
delete_kg_core(
url=args.api_url,
user=args.user,
id=args.id,
token=args.token,
workspace=args.workspace,
)
except Exception as e:

View file

@ -10,12 +10,16 @@ import textwrap
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def delete_mcp_tool(
url : str,
id : str,
token=None,
workspace="default",
):
api = Api(url).config()
api = Api(url, token=token, workspace=workspace).config()
# Check if the tool exists first
try:
@ -73,6 +77,18 @@ def main():
help='MCP tool ID to delete',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
args = parser.parse_args()
try:
@ -81,8 +97,10 @@ def main():
raise RuntimeError("Must specify --id for MCP tool to delete")
delete_mcp_tool(
url=args.api_url,
id=args.id
url=args.api_url,
id=args.id,
token=args.token,
workspace=args.workspace,
)
except Exception as e:

View file

@ -12,12 +12,16 @@ import textwrap
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def delete_tool(
url : str,
id : str,
token=None,
workspace="default",
):
api = Api(url).config()
api = Api(url, token=token, workspace=workspace).config()
# Check if the tool configuration exists
try:
@ -78,6 +82,18 @@ def main():
help='Tool ID to delete',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
args = parser.parse_args()
try:
@ -86,8 +102,10 @@ def main():
raise RuntimeError("Must specify --id for tool to delete")
delete_tool(
url=args.api_url,
id=args.id
url=args.api_url,
id=args.id,
token=args.token,
workspace=args.workspace,
)
except Exception as e:

View file

@ -0,0 +1,114 @@
"""
Exports a curated subset of a workspace's configuration to a JSON file
for later reload into another workspace (useful for cloning test setups).
The subset covers the config types that define workspace behaviour:
mcp-tool, tool, flow-blueprint, token-cost, agent-pattern,
agent-task-type, parameter-type, interface-description, prompt.
"""
import argparse
import os
import json
import sys
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
EXPORT_TYPES = [
"mcp-tool",
"tool",
"flow-blueprint",
"token-cost",
"agent-pattern",
"agent-task-type",
"parameter-type",
"interface-description",
"prompt",
]
def export_workspace_config(url, workspace, output, token=None):
api = Api(url, token=token, workspace=workspace).config()
config, version = api.all()
subset = {}
for t in EXPORT_TYPES:
if t in config:
subset[t] = config[t]
payload = {
"source_workspace": workspace,
"source_version": version,
"config": subset,
}
if output == "-":
json.dump(payload, sys.stdout, indent=2)
sys.stdout.write("\n")
else:
with open(output, "w") as f:
json.dump(payload, f, indent=2)
total = sum(len(v) for v in subset.values())
print(
f"Exported {total} items across {len(subset)} types "
f"from workspace '{workspace}' (version {version}).",
file=sys.stderr,
)
def main():
parser = argparse.ArgumentParser(
prog='tg-export-workspace-config',
description=__doc__,
)
parser.add_argument(
'-u', '--api-url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Source workspace (default: {default_workspace})',
)
parser.add_argument(
'-o', '--output',
required=True,
help='Output JSON file path (use "-" for stdout)',
)
args = parser.parse_args()
try:
export_workspace_config(
url=args.api_url,
workspace=args.workspace,
output=args.output,
token=args.token,
)
except Exception as e:
print("Exception:", e, file=sys.stderr, flush=True)
sys.exit(1)
if __name__ == "__main__":
main()
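
Like the other CLIs touched in this commit, the exporter resolves its workspace and token from a command-line flag first, then from an environment variable, then from a literal default. The sketch below isolates that pattern. The program name `tg-example` and the function `build_parser` are hypothetical and exist only for illustration.

```python
import argparse
import os

# Fallback chain: explicit flag, then environment variable, then default.
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")


def build_parser():
    # "tg-example" is a made-up program name, not a real TrustGraph CLI.
    parser = argparse.ArgumentParser(prog="tg-example")
    parser.add_argument(
        '-t', '--token',
        default=default_token,
        help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
    )
    parser.add_argument(
        '-w', '--workspace',
        default=default_workspace,
        help=f'Workspace (default: {default_workspace})',
    )
    return parser


# An explicit flag wins over the environment-derived default.
explicit = build_parser().parse_args(["-w", "team-a"])
# With no flag, the value falls back to default_workspace.
implicit = build_parser().parse_args([])

print(explicit.workspace)
print(implicit.workspace == default_workspace)
```

Because the defaults are read once at import time, changing `TRUSTGRAPH_WORKSPACE` after the module has loaded has no effect on the parser.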

View file

@ -10,10 +10,12 @@ from trustgraph.api.types import ConfigKey
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def get_config_item(url, config_type, key, format_type, token=None):
def get_config_item(url, config_type, key, format_type, token=None,
workspace="default"):
api = Api(url, token=token).config()
api = Api(url, token=token, workspace=workspace).config()
config_key = ConfigKey(type=config_type, key=key)
values = api.get([config_key])
@ -66,6 +68,12 @@ def main():
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
args = parser.parse_args()
try:
@ -76,6 +84,7 @@ def main():
key=args.key,
format_type=args.format,
token=args.token,
workspace=args.workspace,
)
except Exception as e:

View file

@ -9,21 +9,19 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_user = "trustgraph"
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def get_content(url, user, document_id, output_file, token=None):
def get_content(url, document_id, output_file, token=None, workspace="default"):
api = Api(url, token=token).library()
api = Api(url, token=token, workspace=workspace).library()
content = api.get_document_content(user=user, id=document_id)
content = api.get_document_content(id=document_id)
if output_file:
with open(output_file, 'wb') as f:
f.write(content)
print(f"Written {len(content)} bytes to {output_file}")
else:
# Write to stdout
# Try to decode as text, fall back to binary info
try:
text = content.decode('utf-8')
print(text)
@ -51,9 +49,9 @@ def main():
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
@ -73,10 +71,10 @@ def main():
get_content(
url=args.api_url,
user=args.user,
document_id=args.document_id,
output_file=args.output,
token=args.token,
workspace=args.workspace,
)
except Exception as e:

View file

@ -9,10 +9,12 @@ from trustgraph.api import Api
import json
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def get_flow_blueprint(url, blueprint_name):
def get_flow_blueprint(url, blueprint_name, token=None, workspace="default"):
api = Api(url).flow()
api = Api(url, token=token, workspace=workspace).flow()
cls = api.get_blueprint(blueprint_name)
@ -31,6 +33,18 @@ def main():
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-n', '--blueprint-name',
required=True,
@ -44,6 +58,8 @@ def main():
get_flow_blueprint(
url=args.api_url,
blueprint_name=args.blueprint_name,
token=args.token,
workspace=args.workspace,
)
except Exception as e:

View file

@ -5,7 +5,6 @@ to a local file in msgpack format.
import argparse
import os
import textwrap
import uuid
import asyncio
import json
@ -13,17 +12,16 @@ from websockets.asyncio.client import connect
import msgpack
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
default_user = 'trustgraph'
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def write_triple(f, data):
msg = (
"t",
{
"m": {
"i": data["metadata"]["id"],
"i": data["metadata"]["id"],
"m": data["metadata"]["metadata"],
"u": data["metadata"]["user"],
"c": data["metadata"]["collection"],
},
"t": data["triples"],
@ -36,9 +34,8 @@ def write_ge(f, data):
"ge",
{
"m": {
"i": data["metadata"]["id"],
"i": data["metadata"]["id"],
"m": data["metadata"]["metadata"],
"u": data["metadata"]["user"],
"c": data["metadata"]["collection"],
},
"e": [
@ -52,7 +49,7 @@ def write_ge(f, data):
)
f.write(msgpack.packb(msg, use_bin_type=True))
async def fetch(url, user, id, output, token=None):
async def fetch(url, workspace, id, output, token=None):
if not url.endswith("/"):
url += "/"
@ -68,10 +65,11 @@ async def fetch(url, user, id, output, token=None):
req = json.dumps({
"id": mid,
"workspace": workspace,
"service": "knowledge",
"request": {
"operation": "get-kg-core",
"user": user,
"workspace": workspace,
"id": id,
}
})
@ -124,10 +122,11 @@ def main():
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
@ -154,11 +153,11 @@ def main():
asyncio.run(
fetch(
url = args.url,
user = args.user,
id = args.id,
output = args.output,
token = args.token,
url=args.url,
workspace=args.workspace,
id=args.id,
output=args.output,
token=args.token,
)
)
@ -167,4 +166,4 @@ def main():
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()
main()

View file

@ -13,9 +13,9 @@ import os
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = 'trustgraph'
default_collection = 'default'
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def term_to_rdflib(term):
@ -58,9 +58,10 @@ def term_to_rdflib(term):
return rdflib.term.Literal(str(term))
def show_graph(url, flow_id, user, collection, limit, batch_size, token=None):
def show_graph(url, flow_id, collection, limit, batch_size,
token=None, workspace="default"):
socket = Api(url, token=token).socket()
socket = Api(url, token=token, workspace=workspace).socket()
flow = socket.flow(flow_id)
g = rdflib.Graph()
@ -68,7 +69,7 @@ def show_graph(url, flow_id, user, collection, limit, batch_size, token=None):
try:
for batch in flow.triples_query_stream(
s=None, p=None, o=None,
user=user, collection=collection,
collection=collection,
limit=limit,
batch_size=batch_size,
):
@ -108,12 +109,6 @@ def main():
help=f'Flow ID (default: default)'
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-C', '--collection',
default=default_collection,
@ -126,6 +121,12 @@ def main():
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-l', '--limit',
type=int,
@ -147,11 +148,11 @@ def main():
show_graph(
url = args.api_url,
flow_id = args.flow_id,
user = args.user,
collection = args.collection,
limit = args.limit,
batch_size = args.batch_size,
token = args.token,
workspace = args.workspace,
)
except Exception as e:

View file

@ -0,0 +1,143 @@
"""
Imports a workspace-config dump produced by tg-export-workspace-config
into a target workspace. Writes mcp-tool, tool, flow-blueprint,
token-cost, agent-pattern, agent-task-type, parameter-type,
interface-description and prompt items verbatim.
Existing items with the same (type, key) are overwritten.
"""
import argparse
import os
import json
import sys
from trustgraph.api import Api
from trustgraph.api.types import ConfigValue
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
IMPORT_TYPES = {
"mcp-tool",
"tool",
"flow-blueprint",
"token-cost",
"agent-pattern",
"agent-task-type",
"parameter-type",
"interface-description",
"prompt",
}
def import_workspace_config(url, workspace, input_path, token=None,
dry_run=False):
if input_path == "-":
payload = json.load(sys.stdin)
else:
with open(input_path, "r") as f:
payload = json.load(f)
# Accept both the wrapped export format and a bare {type: {key: value}}
# dict, so hand-written files are also loadable.
if (isinstance(payload, dict) and "config" in payload
and isinstance(payload["config"], dict)):
config = payload["config"]
source = payload.get("source_workspace")
else:
config = payload
source = None
skipped_types = set(config.keys()) - IMPORT_TYPES
if skipped_types:
print(
f"Ignoring unsupported types: {sorted(skipped_types)}",
file=sys.stderr,
)
values = []
for t in IMPORT_TYPES:
items = config.get(t, {})
for key, value in items.items():
values.append(ConfigValue(type=t, key=key, value=value))
if not values:
print("Nothing to import.", file=sys.stderr)
return
if dry_run:
print(
f"[dry-run] would import {len(values)} items into "
f"workspace '{workspace}'"
+ (f" (from '{source}')" if source else "")
)
return
api = Api(url, token=token, workspace=workspace).config()
api.put(values)
print(
f"Imported {len(values)} items into workspace '{workspace}'"
+ (f" (from '{source}')." if source else "."),
)
def main():
parser = argparse.ArgumentParser(
prog='tg-import-workspace-config',
description=__doc__,
)
parser.add_argument(
'-u', '--api-url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Target workspace (default: {default_workspace})',
)
parser.add_argument(
'-i', '--input',
required=True,
help='Input JSON file path (use "-" for stdin)',
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Parse and validate the input without writing anything',
)
args = parser.parse_args()
try:
import_workspace_config(
url=args.api_url,
workspace=args.workspace,
input_path=args.input,
token=args.token,
dry_run=args.dry_run,
)
except Exception as e:
print("Exception:", e, flush=True)
sys.exit(1)
if __name__ == "__main__":
main()
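
The input handling in `import_workspace_config` can be looked at in isolation. The sketch below is not part of the commit: `select_items` is a hypothetical helper, plain tuples stand in for `ConfigValue`, the type set is trimmed to two entries, and `unknown-type` and `staging` are made-up values. It illustrates how the wrapped export format and a bare `{type: {key: value}}` dict reduce to the same item list, with unsupported types reported separately.

```python
# Illustrative sketch only; `select_items` is a hypothetical helper and is
# not part of the commit. Tuples stand in for trustgraph's ConfigValue.

SUPPORTED = {"tool", "prompt"}  # trimmed stand-in for IMPORT_TYPES


def select_items(payload, supported=SUPPORTED):
    """Return (items, skipped_types, source_workspace) for a parsed dump."""
    # Wrapped export format: {"source_workspace": ..., "config": {...}}
    if isinstance(payload, dict) and isinstance(payload.get("config"), dict):
        config = payload["config"]
        source = payload.get("source_workspace")
    else:
        # Bare {type: {key: value}} dict, e.g. a hand-written file
        config = payload
        source = None

    # Types outside the supported set are reported, not imported
    skipped = sorted(set(config) - supported)

    # Sorted here only to make the sketch's output deterministic
    items = [
        (t, key, value)
        for t in sorted(supported)
        for key, value in config.get(t, {}).items()
    ]
    return items, skipped, source


wrapped = {
    "source_workspace": "staging",
    "config": {"tool": {"search": "{}"}, "unknown-type": {"x": "{}"}},
}
bare = {"tool": {"search": "{}"}, "unknown-type": {"x": "{}"}}

wrapped_result = select_items(wrapped)
bare_result = select_items(bare)
```

Both inputs yield the same single `tool` item; only the wrapped form carries a source workspace name, which the tool uses purely for its status message.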


@ -69,10 +69,11 @@ def ensure_namespace(url, tenant, namespace, config):
print(f"Namespace {tenant}/{namespace} created.", flush=True)
def ensure_config(config, **pubsub_config):
def ensure_config(config, workspace="default", **pubsub_config):
cli = ConfigClient(
subscriber=subscriber,
workspace=workspace,
**pubsub_config,
)
@ -147,7 +148,8 @@ def init_pulsar(pulsar_admin_url, tenant):
})
def push_config(config_json, config_file, **pubsub_config):
def push_config(config_json, config_file, workspace="default",
**pubsub_config):
"""Push initial config if provided."""
if config_json is not None:
@ -160,7 +162,7 @@ def push_config(config_json, config_file, **pubsub_config):
print("Exception:", e, flush=True)
raise e
ensure_config(dec, **pubsub_config)
ensure_config(dec, workspace=workspace, **pubsub_config)
elif config_file is not None:
@ -172,7 +174,7 @@ def push_config(config_json, config_file, **pubsub_config):
print("Exception:", e, flush=True)
raise e
ensure_config(dec, **pubsub_config)
ensure_config(dec, workspace=workspace, **pubsub_config)
else:
print("No config to update.", flush=True)
@ -207,6 +209,12 @@ def main():
help=f'Tenant (default: tg)',
)
parser.add_argument(
'-w', '--workspace',
default="default",
help=f'Workspace (default: default)',
)
add_pubsub_args(parser)
args = parser.parse_args()
@ -216,7 +224,10 @@ def main():
# Extract pubsub config from args
pubsub_config = {
k: v for k, v in vars(args).items()
if k not in ('pulsar_admin_url', 'config', 'config_file', 'tenant')
if k not in (
'pulsar_admin_url', 'config', 'config_file', 'tenant',
'workspace',
)
}
while True:
@ -241,6 +252,7 @@ def main():
# Push config (works with any backend)
push_config(
args.config, args.config_file,
workspace=args.workspace,
**pubsub_config,
)
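
The widened exclusion tuple matters because everything left in `pubsub_config` is forwarded as keyword arguments. A minimal sketch of that interaction follows, assuming a stand-in `push_config` that only echoes its arguments; the `pubsub_backend` key, the admin URL and the `acme` workspace name are hypothetical values chosen for illustration.

```python
import argparse

# Keys consumed by the script itself rather than by the pub/sub layer
SCRIPT_ONLY = (
    'pulsar_admin_url', 'config', 'config_file', 'tenant',
    'workspace',
)


def split_pubsub_config(args):
    """Keep only the parsed arguments intended for the pub/sub backend."""
    return {k: v for k, v in vars(args).items() if k not in SCRIPT_ONLY}


def push_config(config_json, config_file, workspace="default",
                **pubsub_config):
    # Stand-in for the real function: just report what was received
    return workspace, pubsub_config


args = argparse.Namespace(
    pulsar_admin_url="http://pulsar:8080",  # hypothetical value
    config=None,
    config_file=None,
    tenant="tg",
    workspace="acme",           # hypothetical workspace name
    pubsub_backend="pulsar",    # hypothetical pub/sub option
)

pubsub_config = split_pubsub_config(args)
result = push_config(
    args.config, args.config_file,
    workspace=args.workspace,
    **pubsub_config,
)

# If 'workspace' were left in the dict, the keyword would be passed twice
unfiltered = dict(pubsub_config, workspace=args.workspace)
try:
    push_config(None, None, workspace=args.workspace, **unfiltered)
    duplicate_error = None
except TypeError as e:
    duplicate_error = str(e)
```

Without `'workspace'` in the exclusion list, Python rejects the call with a `TypeError` for a repeated keyword argument.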


@ -26,7 +26,7 @@ from trustgraph.api import (
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_user = 'trustgraph'
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
default_collection = 'default'
class Outputter:
@ -119,7 +119,7 @@ def question_explainable(
state=None, group=None, verbose=False, token=None, debug=False
):
"""Execute agent with explainability - shows provenance events inline."""
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
@ -296,7 +296,7 @@ def question(
print()
# Create API client
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
@ -418,6 +418,12 @@ def main():
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
default="default",
@ -430,12 +436,6 @@ def main():
help=f'Question to answer',
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-C', '--collection',
default=default_collection,
@ -502,7 +502,7 @@ def main():
url = args.url,
flow_id = args.flow_id,
question = args.question,
user = args.user,
user = "",
collection = args.collection,
plan = args.plan,
state = args.state,


@ -9,11 +9,12 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def query(url, flow_id, query_text, user, collection, limit, token=None):
def query(url, flow_id, query_text, collection, limit, token=None, workspace="default"):
# Create API client
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
@ -21,7 +22,6 @@ def query(url, flow_id, query_text, user, collection, limit, token=None):
# Call document embeddings query service
result = flow.document_embeddings_query(
text=query_text,
user=user,
collection=collection,
limit=limit
)
@ -59,15 +59,15 @@ def main():
)
parser.add_argument(
'-f', '--flow-id',
default="default",
help=f'Flow ID (default: default)'
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-U', '--user',
default="trustgraph",
help='User/keyspace (default: trustgraph)',
'-f', '--flow-id',
default="default",
help=f'Flow ID (default: default)'
)
parser.add_argument(
@ -97,10 +97,10 @@ def main():
url=args.url,
flow_id=args.flow_id,
query_text=args.query[0],
user=args.user,
collection=args.collection,
limit=args.limit,
token=args.token,
workspace=args.workspace,
)
except Exception as e:


@ -18,16 +18,17 @@ from trustgraph.api import (
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_user = 'trustgraph'
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
default_collection = 'default'
default_doc_limit = 10
def question_explainable(
url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False
url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False,
workspace="default",
):
"""Execute document RAG with explainability - shows provenance events inline."""
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
@ -100,7 +101,7 @@ def question_explainable(
def question(
url, flow_id, question_text, user, collection, doc_limit,
streaming=True, token=None, explainable=False, debug=False,
show_usage=False
show_usage=False, workspace="default",
):
# Explainable mode uses the API to capture and process provenance events
if explainable:
@ -112,12 +113,13 @@ def question(
collection=collection,
doc_limit=doc_limit,
token=token,
debug=debug
debug=debug,
workspace=workspace,
)
return
# Create API client
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
if streaming:
# Use socket client for streaming
@ -189,6 +191,12 @@ def main():
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
default="default",
@ -201,12 +209,6 @@ def main():
help=f'Question to answer',
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-C', '--collection',
default=default_collection,
@ -252,7 +254,7 @@ def main():
url=args.url,
flow_id=args.flow_id,
question_text=args.question,
user=args.user,
user="",
collection=args.collection,
doc_limit=args.doc_limit,
streaming=not args.no_streaming,
@ -260,6 +262,7 @@ def main():
explainable=args.explainable,
debug=args.debug,
show_usage=args.show_usage,
workspace=args.workspace,
)
except Exception as e:


@ -9,11 +9,12 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def query(url, flow_id, texts, token=None):
def query(url, flow_id, texts, token=None, workspace="default"):
# Create API client
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
@ -51,6 +52,12 @@ def main():
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
default="default",
@ -72,6 +79,8 @@ def main():
flow_id=args.flow_id,
texts=args.texts,
token=args.token,
workspace=args.workspace,
)
except Exception as e:
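
The same `-w/--workspace` pattern recurs across the CLI tools in this commit: an explicit flag wins, otherwise `TRUSTGRAPH_WORKSPACE` applies, otherwise the literal `default`. The sketch below illustrates that precedence under stated assumptions: `build_parser` and the `tg-example` program name are hypothetical, and the environment is passed in as a mapping so the behaviour can be checked without touching the real process environment (the commit's scripts read `os.getenv` once at module import).

```python
import argparse
import os


def build_parser(environ=os.environ):
    """Build a parser whose workspace default comes from the environment."""
    default_workspace = environ.get("TRUSTGRAPH_WORKSPACE", "default")

    parser = argparse.ArgumentParser(prog="tg-example")  # hypothetical name
    parser.add_argument(
        '-w', '--workspace',
        default=default_workspace,
        help=f'Workspace (default: {default_workspace})',
    )
    return parser


env = {"TRUSTGRAPH_WORKSPACE": "staging"}

# Explicit flag overrides the environment variable
from_flag = build_parser(env).parse_args(["-w", "acme"]).workspace

# No flag: the environment variable supplies the value
from_env = build_parser(env).parse_args([]).workspace

# Neither flag nor variable: fall back to "default"
fallback = build_parser({}).parse_args([]).workspace
```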


@ -9,11 +9,12 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def query(url, flow_id, query_text, user, collection, limit, token=None):
def query(url, flow_id, query_text, collection, limit, token=None, workspace="default"):
# Create API client
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
@ -21,7 +22,6 @@ def query(url, flow_id, query_text, user, collection, limit, token=None):
# Call graph embeddings query service
result = flow.graph_embeddings_query(
text=query_text,
user=user,
collection=collection,
limit=limit
)
@ -69,15 +69,15 @@ def main():
)
parser.add_argument(
'-f', '--flow-id',
default="default",
help=f'Flow ID (default: default)'
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-U', '--user',
default="trustgraph",
help='User/keyspace (default: trustgraph)',
'-f', '--flow-id',
default="default",
help=f'Flow ID (default: default)'
)
parser.add_argument(
@ -107,10 +107,10 @@ def main():
url=args.url,
flow_id=args.flow_id,
query_text=args.query[0],
user=args.user,
collection=args.collection,
limit=args.limit,
token=args.token,
workspace=args.workspace,
)
except Exception as e:


@ -22,7 +22,7 @@ from trustgraph.api import (
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_user = 'trustgraph'
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
default_collection = 'default'
default_entity_limit = 50
default_triple_limit = 30
@ -641,10 +641,10 @@ async def _question_explainable(
def _question_explainable_api(
url, flow_id, question_text, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, edge_score_limit=30,
edge_limit=25, token=None, debug=False
edge_limit=25, token=None, debug=False, workspace="default",
):
"""Execute graph RAG with explainability using the new API classes."""
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10)
@ -753,7 +753,8 @@ def question(
url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, edge_score_limit=50,
edge_limit=25, streaming=True, token=None,
explainable=False, debug=False, show_usage=False
explainable=False, debug=False, show_usage=False,
workspace="default",
):
# Explainable mode uses the API to capture and process provenance events
@ -771,12 +772,13 @@ def question(
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
token=token,
debug=debug
debug=debug,
workspace=workspace,
)
return
# Create API client
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
if streaming:
# Use socket client for streaming
@ -857,6 +859,12 @@ def main():
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
default="default",
@ -869,12 +877,6 @@ def main():
help=f'Question to answer',
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-C', '--collection',
default=default_collection,
@ -955,7 +957,7 @@ def main():
url=args.url,
flow_id=args.flow_id,
question=args.question,
user=args.user,
user="",
collection=args.collection,
entity_limit=args.entity_limit,
triple_limit=args.triple_limit,
@ -968,6 +970,7 @@ def main():
explainable=args.explainable,
debug=args.debug,
show_usage=args.show_usage,
workspace=args.workspace,
)
except Exception as e:


@ -9,12 +9,13 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def query(url, flow_id, system, prompt, streaming=True, token=None,
show_usage=False):
show_usage=False, workspace="default"):
# Create API client
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
@ -74,6 +75,12 @@ def main():
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'system',
nargs=1,
@ -116,6 +123,7 @@ def main():
streaming=not args.no_streaming,
token=args.token,
show_usage=args.show_usage,
workspace=args.workspace,
)
except Exception as e:


@ -11,10 +11,12 @@ import json
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def query(url, flow_id, name, parameters):
def query(url, flow_id, name, parameters, token=None, workspace="default"):
api = Api(url).flow().id(flow_id)
api = Api(url, token=token, workspace=workspace).flow().id(flow_id)
resp = api.mcp_tool(name=name, parameters=parameters)
@ -36,6 +38,18 @@ def main():
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
default="default",
@ -68,6 +82,8 @@ def main():
flow_id = args.flow_id,
name = args.name,
parameters = parameters,
token = args.token,
workspace = args.workspace,
)
except Exception as e:


@ -10,9 +10,11 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
def nlp_query(url, flow_id, question, max_results, output_format='json'):
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def nlp_query(url, flow_id, question, max_results, output_format='json', token=None, workspace="default"):
api = Api(url).flow().id(flow_id)
api = Api(url, token=token, workspace=workspace).flow().id(flow_id)
resp = api.nlp_query(
question=question,
@ -63,6 +65,17 @@ def main():
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
@ -100,6 +113,11 @@ def main():
question=args.question,
max_results=args.max_results,
output_format=args.format,
token = args.token,
workspace = args.workspace,
)
except Exception as e:


@ -14,12 +14,13 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def query(url, flow_id, template_id, variables, streaming=True, token=None,
show_usage=False):
show_usage=False, workspace="default"):
# Create API client
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
@ -80,6 +81,12 @@ def main():
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
default="default",
@ -135,6 +142,7 @@ specified multiple times''',
streaming=not args.no_streaming,
token=args.token,
show_usage=args.show_usage,
workspace=args.workspace,
)
except Exception as e:


@ -9,11 +9,12 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def query(url, flow_id, query_text, schema_name, user, collection, index_name, limit, token=None):
def query(url, flow_id, query_text, schema_name, collection, index_name, limit, token=None, workspace="default"):
# Create API client
api = Api(url=url, token=token)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
@ -22,7 +23,6 @@ def query(url, flow_id, query_text, schema_name, user, collection, index_name, l
result = flow.row_embeddings_query(
text=query_text,
schema_name=schema_name,
user=user,
collection=collection,
index_name=index_name,
limit=limit
@ -60,15 +60,15 @@ def main():
)
parser.add_argument(
'-f', '--flow-id',
default="default",
help=f'Flow ID (default: default)'
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-U', '--user',
default="trustgraph",
help='User/keyspace (default: trustgraph)',
'-f', '--flow-id',
default="default",
help=f'Flow ID (default: default)'
)
parser.add_argument(
@ -111,11 +111,11 @@ def main():
flow_id=args.flow_id,
query_text=args.query[0],
schema_name=args.schema_name,
user=args.user,
collection=args.collection,
index_name=args.index_name,
limit=args.limit,
token=args.token,
workspace=args.workspace,
)
except Exception as e:


@ -12,10 +12,11 @@ from trustgraph.api import Api
from tabulate import tabulate
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = 'trustgraph'
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
default_collection = 'default'
def format_output(data, output_format):
def format_output(data, output_format, token=None, workspace="default"):
"""Format GraphQL response data in the specified format"""
if not data:
return "No data returned"
@ -82,10 +83,10 @@ def format_table_data(rows, table_name, output_format):
return json.dumps({table_name: rows}, indent=2)
def rows_query(
url, flow_id, query, user, collection, variables, operation_name, output_format='table'
url, flow_id, query, collection, variables, operation_name, output_format='table', token=None, workspace="default"
):
api = Api(url).flow().id(flow_id)
api = Api(url, token=token, workspace=workspace).flow().id(flow_id)
# Parse variables if provided as JSON string
parsed_variables = {}
@ -98,7 +99,6 @@ def rows_query(
resp = api.rows_query(
query=query,
user=user,
collection=collection,
variables=parsed_variables if parsed_variables else None,
operation_name=operation_name
@ -135,6 +135,17 @@ def main():
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
@ -148,12 +159,6 @@ def main():
help='GraphQL query to execute',
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-C', '--collection',
default=default_collection,
@ -185,11 +190,13 @@ def main():
url=args.url,
flow_id=args.flow_id,
query=args.query,
user=args.user,
collection=args.collection,
variables=args.variables,
operation_name=args.operation_name,
output_format=args.format,
token=args.token,
workspace=args.workspace,
)
except Exception as e:


@ -9,7 +9,8 @@ import sys
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = 'trustgraph'
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
default_collection = 'default'
@ -44,10 +45,10 @@ def _term_str(val):
return str(val)
def sparql_query(url, token, flow_id, query, user, collection, limit,
batch_size, output_format):
def sparql_query(url, token, flow_id, query, collection, limit,
batch_size, output_format, workspace="default"):
socket = Api(url=url, token=token).socket()
socket = Api(url=url, token=token, workspace=workspace).socket()
flow = socket.flow(flow_id)
variables = None
@ -57,7 +58,6 @@ def sparql_query(url, token, flow_id, query, user, collection, limit,
for response in flow.sparql_query_stream(
query=query,
user=user,
collection=collection,
limit=limit,
batch_size=batch_size,
@ -154,8 +154,14 @@ def main():
parser.add_argument(
'-t', '--token',
default=os.getenv("TRUSTGRAPH_TOKEN"),
help='API bearer token (default: TRUSTGRAPH_TOKEN env var)',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
@ -174,12 +180,6 @@ def main():
help='Read SPARQL query from file (use - for stdin)',
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})',
)
parser.add_argument(
'-C', '--collection',
default=default_collection,
@ -228,11 +228,11 @@ def main():
token=args.token,
flow_id=args.flow_id,
query=query,
user=args.user,
collection=args.collection,
limit=args.limit,
batch_size=args.batch_size,
output_format=args.format,
workspace=args.workspace,
)
except Exception as e:


@ -13,7 +13,9 @@ from tabulate import tabulate
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
def format_output(data, output_format):
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def format_output(data, output_format, token=None, workspace="default"):
"""Format structured query response data in the specified format"""
if not data:
return "No data returned"
@ -79,11 +81,11 @@ def format_table_data(rows, table_name, output_format):
else:
return json.dumps({table_name: rows}, indent=2)
def structured_query(url, flow_id, question, user='trustgraph', collection='default', output_format='table'):
def structured_query(url, flow_id, question, collection='default', output_format='table', token=None, workspace="default"):
api = Api(url).flow().id(flow_id)
api = Api(url, token=token, workspace=workspace).flow().id(flow_id)
resp = api.structured_query(question=question, user=user, collection=collection)
resp = api.structured_query(question=question, collection=collection)
# Check for errors
if "error" in resp and resp["error"]:
@ -119,6 +121,17 @@ def main():
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
@ -132,12 +145,6 @@ def main():
help='Natural language question to execute',
)
parser.add_argument(
'--user',
default='trustgraph',
help='Cassandra keyspace identifier (default: trustgraph)'
)
parser.add_argument(
'--collection',
default='default',
@ -159,9 +166,12 @@ def main():
url=args.url,
flow_id=args.flow_id,
question=args.question,
user=args.user,
collection=args.collection,
output_format=args.format,
token=args.token,
workspace = args.workspace,
)
except Exception as e:

Some files were not shown because too many files have changed in this diff.