diff --git a/.gitignore b/.gitignore index daeba074..32942156 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ trustgraph-parquet/trustgraph/parquet_version.py trustgraph-vertexai/trustgraph/vertexai_version.py trustgraph-unstructured/trustgraph/unstructured_version.py trustgraph-mcp/trustgraph/mcp_version.py +trustgraph/trustgraph/trustgraph_version.py vertexai/ \ No newline at end of file diff --git a/Makefile b/Makefile index 85f10fdd..0f0f37b2 100644 --- a/Makefile +++ b/Makefile @@ -57,7 +57,7 @@ container-bedrock container-vertexai \ container-hf container-ocr \ container-unstructured container-mcp -some-containers: container-base container-flow +some-containers: container-base container-flow container-unstructured push: ${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION} diff --git a/docs/tech-specs/data-ownership-model.md b/docs/tech-specs/data-ownership-model.md new file mode 100644 index 00000000..ea94ec46 --- /dev/null +++ b/docs/tech-specs/data-ownership-model.md @@ -0,0 +1,309 @@ +--- +layout: default +title: "Data Ownership and Information Separation" +parent: "Tech Specs" +--- + +# Data Ownership and Information Separation + +## Purpose + +This document defines the logical ownership model for data in +TrustGraph: what the artefacts are, who owns them, and how they relate +to each other. + +The IAM spec ([iam.md](iam.md)) describes authentication and +authorisation mechanics. This spec addresses the prior question: what +are the boundaries around data, and who owns what? + +## Concepts + +### Workspace + +A workspace is the primary isolation boundary. It represents an +organisation, team, or independent operating unit. All data belongs to +exactly one workspace. Cross-workspace access is never permitted through +the API. + +A workspace owns: +- Source documents +- Flows (processing pipeline definitions) +- Knowledge cores (stored extraction output) +- Collections (organisational units for extracted knowledge) + +### Collection + +A collection is an organisational unit within a workspace. It groups +extracted knowledge produced from source documents. A workspace can +have multiple collections, allowing: + +- Processing the same documents with different parameters or models. +- Maintaining separate knowledge bases for different purposes. +- Deleting extracted knowledge without deleting source documents. + +Collections do not own source documents. A source document exists at the +workspace level and can be processed into multiple collections. + +### Source document + +A source document (PDF, text file, etc.) is raw input uploaded to the +system. Documents belong to the workspace, not to a specific collection. + +This is intentional. A document is an asset that exists independently +of how it is processed. The same PDF might be processed into multiple +collections with different chunking parameters or extraction models. +Tying a document to a single collection would force re-upload for each +collection. + +### Flow + +A flow defines a processing pipeline: which models to use, what +parameters to apply (chunk size, temperature, etc.), and how processing +services are connected. Flows belong to a workspace. + +The processing services themselves (document-decoder, chunker, +embeddings, LLM completion, etc.) are shared infrastructure — they serve +all workspaces. Each flow has its own queues, keeping data from +different workspaces and flows separate as it moves through the +pipeline. + +Different workspaces can define different flows. Workspace A might use +GPT-5.2 with a chunk size of 2000, while workspace B uses Claude with a +chunk size of 1000. + +### Prompts + +Prompts are templates that control how the LLM behaves during knowledge +extraction and query answering. They belong to a workspace, allowing +different workspaces to have different extraction strategies, response +styles, or domain-specific instructions. + +### Ontology + +An ontology defines the concepts, entities, and relationships that the +extraction pipeline looks for in source documents. Ontologies belong to +a workspace. A medical workspace might define ontologies around diseases, +symptoms, and treatments, while a legal workspace defines ontologies +around statutes, precedents, and obligations. + +### Schemas + +Schemas define structured data types for extraction. They specify what +fields to extract, their types, and how they relate. Schemas belong to +a workspace, as different workspaces extract different structured +information from their documents. + +### Tools, tool services, and MCP tools + +Tools define capabilities available to agents: what actions they can +take, what external services they can call. Tool services configure how +tools connect to backend services. MCP tools configure connections to +remote MCP servers, including authentication tokens. All belong to a +workspace. + +### Agent patterns and agent task types + +Agent patterns define agent behaviour strategies (how an agent reasons, +what steps it follows). Agent task types define the kinds of tasks +agents can perform. Both belong to a workspace, as different workspaces +may have different agent configurations. + +### Token costs + +Token cost definitions specify pricing for LLM token usage per model. +These belong to a workspace since different workspaces may use different +models or have different billing arrangements. + +### Flow blueprints + +Flow blueprints are templates for creating flows. They define the +default pipeline structure and parameters. Blueprints belong to a +workspace, allowing workspaces to define custom processing templates. + +### Parameter types + +Parameter types define the kinds of parameters that flows accept (e.g. +"llm-model", "temperature"), including their defaults and validation +rules. They belong to a workspace since workspaces that define custom +flows need to define the parameter types those flows use. + +### Interface descriptions + +Interface descriptions define the connection points of a flow — what +queues and topics it uses. They belong to a workspace since they +describe workspace-owned flows. + +### Knowledge core + +A knowledge core is a stored snapshot of extracted knowledge (triples +and graph embeddings). Knowledge cores belong to a workspace and can be +loaded into any collection within that workspace. + +Knowledge cores serve as a portable extraction output. You process +documents through a flow, the pipeline produces triples and embeddings, +and the results can be stored as a knowledge core. That core can later +be loaded into a different collection or reloaded after a collection is +cleared. + +### Extracted knowledge + +Extracted knowledge is the live, queryable content within a collection: +triples in the knowledge graph, graph embeddings, and document +embeddings. It is the product of processing source documents through a +flow into a specific collection. + +Extracted knowledge is scoped to a workspace and a collection. It +cannot exist without both. + +### Processing record + +A processing record tracks which source document was processed, through +which flow, into which collection. It links the source document +(workspace-scoped) to the extracted knowledge (workspace + collection +scoped). + +## Ownership summary + +| Artefact | Owned by | Shared across collections? | +|----------|----------|---------------------------| +| Workspaces | Global (platform) | N/A | +| User accounts | Global (platform) | N/A | +| API keys | Global (platform) | N/A | +| Source documents | Workspace | Yes | +| Flows | Workspace | N/A | +| Flow blueprints | Workspace | N/A | +| Prompts | Workspace | N/A | +| Ontologies | Workspace | N/A | +| Schemas | Workspace | N/A | +| Tools | Workspace | N/A | +| Tool services | Workspace | N/A | +| MCP tools | Workspace | N/A | +| Agent patterns | Workspace | N/A | +| Agent task types | Workspace | N/A | +| Token costs | Workspace | N/A | +| Parameter types | Workspace | N/A | +| Interface descriptions | Workspace | N/A | +| Knowledge cores | Workspace | Yes — can be loaded into any collection | +| Collections | Workspace | N/A | +| Extracted knowledge | Workspace + collection | No | +| Processing records | Workspace + collection | No | + +## Scoping summary + +### Global (system-level) + +A small number of artefacts exist outside any workspace: + +- **Workspace registry** — the list of workspaces itself +- **User accounts** — users reference a workspace but are not owned by + one +- **API keys** — belong to users, not workspaces + +These are managed by the IAM layer and exist at the platform level. + +### Workspace-owned + +All other configuration and data is workspace-owned: + +- Flow definitions and parameters +- Flow blueprints +- Prompts +- Ontologies +- Schemas +- Tools, tool services, and MCP tools +- Agent patterns and agent task types +- Token costs +- Parameter types +- Interface descriptions +- Collection definitions +- Knowledge cores +- Source documents +- Collections and their extracted knowledge + +## Relationship between artefacts + +``` +Platform (global) + | + +-- Workspaces + | | + +-- User accounts (each assigned to a workspace) + | | + +-- API keys (belong to users) + +Workspace + | + +-- Source documents (uploaded, unprocessed) + | + +-- Flows (pipeline definitions: models, parameters, queues) + | + +-- Flow blueprints (templates for creating flows) + | + +-- Prompts (LLM instruction templates) + | + +-- Ontologies (entity and relationship definitions) + | + +-- Schemas (structured data type definitions) + | + +-- Tools, tool services, MCP tools (agent capabilities) + | + +-- Agent patterns and agent task types (agent behaviour) + | + +-- Token costs (LLM pricing per model) + | + +-- Parameter types (flow parameter definitions) + | + +-- Interface descriptions (flow connection points) + | + +-- Knowledge cores (stored extraction snapshots) + | + +-- Collections + | + +-- Extracted knowledge (triples, embeddings) + | + +-- Processing records (links documents to collections) +``` + +A typical workflow: + +1. A source document is uploaded to the workspace. +2. A flow defines how to process it (which models, what parameters). +3. The document is processed through the flow into a collection. +4. Processing records track what was processed. +5. Extracted knowledge (triples, embeddings) is queryable within the + collection. +6. Optionally, the extracted knowledge is stored as a knowledge core + for later reuse. + +## Implementation notes + +The current codebase uses a `user` field in message metadata and storage +partition keys to identify the workspace. The `collection` field +identifies the collection within that workspace. The IAM spec describes +how the gateway maps authenticated credentials to a workspace identity +and sets these fields. + +For details on how each storage backend implements this scoping, see: + +- [Entity-Centric Graph](entity-centric-graph.md) — Cassandra KG schema +- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md) +- [Collection Management](collection-management.md) + +### Known inconsistencies in current implementation + +- **Pipeline intermediate tables** do not include collection in their + partition keys. Re-processing the same document into a different + collection may overwrite intermediate state. +- **Processing metadata** stores collection in the row payload but not + in the partition key, making collection-based queries inefficient. +- **Upload sessions** are keyed by upload ID, not workspace. The + gateway should validate workspace ownership before allowing + operations on upload sessions. + +## References + +- [Identity and Access Management](iam.md) +- [Collection Management](collection-management.md) +- [Entity-Centric Graph](entity-centric-graph.md) +- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md) +- [Multi-Tenant Support](multi-tenant-support.md) diff --git a/docs/tech-specs/flow-class-definition.md b/docs/tech-specs/flow-class-definition.md index 94229b72..3a81bf71 100644 --- a/docs/tech-specs/flow-class-definition.md +++ b/docs/tech-specs/flow-class-definition.md @@ -20,8 +20,8 @@ Defines shared service processors that are instantiated once per flow blueprint. ```json "class": { "service-name:{class}": { - "request": "queue-pattern:{class}", - "response": "queue-pattern:{class}", + "request": "queue-pattern:{workspace}:{class}", + "response": "queue-pattern:{workspace}:{class}", "settings": { "setting-name": "fixed-value", "parameterized-setting": "{parameter-name}" @@ -31,11 +31,11 @@ Defines shared service processors that are instantiated once per flow blueprint. ``` **Characteristics:** -- Shared across all flow instances of the same class +- Shared across all flow instances of the same class within a workspace - Typically expensive or stateless services (LLMs, embedding models) -- Use `{class}` template variable for queue naming +- Use `{workspace}` and `{class}` template variables for queue naming - Settings can be fixed values or parameterized with `{parameter-name}` syntax -- Examples: `embeddings:{class}`, `text-completion:{class}`, `graph-rag:{class}` +- Examples: `embeddings:{workspace}:{class}`, `text-completion:{workspace}:{class}` ### 2. Flow Section Defines flow-specific processors that are instantiated for each individual flow instance. Each flow gets its own isolated set of these processors. @@ -43,8 +43,8 @@ Defines flow-specific processors that are instantiated for each individual flow ```json "flow": { "processor-name:{id}": { - "input": "queue-pattern:{id}", - "output": "queue-pattern:{id}", + "input": "queue-pattern:{workspace}:{id}", + "output": "queue-pattern:{workspace}:{id}", "settings": { "setting-name": "fixed-value", "parameterized-setting": "{parameter-name}" @@ -56,9 +56,9 @@ Defines flow-specific processors that are instantiated for each individual flow **Characteristics:** - Unique instance per flow - Handle flow-specific data and state -- Use `{id}` template variable for queue naming +- Use `{workspace}` and `{id}` template variables for queue naming - Settings can be fixed values or parameterized with `{parameter-name}` syntax -- Examples: `chunker:{id}`, `pdf-decoder:{id}`, `kg-extract-relationships:{id}` +- Examples: `chunker:{workspace}:{id}`, `pdf-decoder:{workspace}:{id}` ### 3. Interfaces Section Defines the entry points and interaction contracts for the flow. These form the API surface for external systems and internal component communication. @@ -68,8 +68,8 @@ Interfaces can take two forms: **Fire-and-Forget Pattern** (single queue): ```json "interfaces": { - "document-load": "persistent://tg/flow/document-load:{id}", - "triples-store": "persistent://tg/flow/triples-store:{id}" + "document-load": "persistent://tg/flow/{workspace}:document-load:{id}", + "triples-store": "persistent://tg/flow/{workspace}:triples-store:{id}" } ``` @@ -77,8 +77,8 @@ Interfaces can take two forms: ```json "interfaces": { "embeddings": { - "request": "non-persistent://tg/request/embeddings:{class}", - "response": "non-persistent://tg/response/embeddings:{class}" + "request": "non-persistent://tg/request/{workspace}:embeddings:{class}", + "response": "non-persistent://tg/response/{workspace}:embeddings:{class}" } } ``` @@ -117,6 +117,16 @@ Additional information about the flow blueprint: ### System Variables +#### {workspace} +- Replaced with the workspace identifier +- Isolates queue names between workspaces so that two workspaces + starting the same flow do not share queues +- Must be included in all queue name patterns to ensure workspace + isolation +- Example: `ws-acme`, `ws-globex` +- All blueprint templates must include `{workspace}` in queue name + patterns + #### {id} - Replaced with the unique flow instance identifier - Creates isolated resources for each flow diff --git a/docs/tech-specs/iam.md b/docs/tech-specs/iam.md new file mode 100644 index 00000000..5de50749 --- /dev/null +++ b/docs/tech-specs/iam.md @@ -0,0 +1,858 @@ +--- +layout: default +title: "Identity and Access Management" +parent: "Tech Specs" +--- + +# Identity and Access Management + +## Problem Statement + +TrustGraph has no meaningful identity or access management. The system +relies on a single shared gateway token for authentication and an +honour-system `user` query parameter for data isolation. This creates +several problems: + +- **No user identity.** There are no user accounts, no login, and no way + to know who is making a request. The `user` field in message metadata + is a caller-supplied string with no validation — any client can claim + to be any user. + +- **No access control.** A valid gateway token grants unrestricted access + to every endpoint, every user's data, every collection, and every + administrative operation. There is no way to limit what an + authenticated caller can do. + +- **No credential isolation.** All callers share one static token. There + is no per-user credential, no token expiration, and no rotation + mechanism. Revoking access means changing the shared token, which + affects all callers. + +- **Data isolation is unenforced.** Storage backends (Cassandra, Neo4j, + Qdrant) filter queries by `user` and `collection`, but the gateway + does not prevent a caller from specifying another user's identity. + Cross-user data access is trivial. + +- **No audit trail.** There is no logging of who accessed what. Without + user identity, audit logging is impossible. + +These gaps make the system unsuitable for multi-user deployments, +multi-tenant SaaS, or any environment where access needs to be +controlled or audited. + +## Current State + +### Authentication + +The API gateway supports a single shared token configured via the +`GATEWAY_SECRET` environment variable or `--api-token` CLI argument. If +unset, authentication is disabled entirely. When enabled, every HTTP +endpoint requires an `Authorization: Bearer ` header. WebSocket +connections pass the token as a query parameter. + +Implementation: `trustgraph-flow/trustgraph/gateway/auth.py` + +```python +class Authenticator: + def __init__(self, token=None, allow_all=False): + self.token = token + self.allow_all = allow_all + + def permitted(self, token, roles): + if self.allow_all: return True + if self.token != token: return False + return True +``` + +The `roles` parameter is accepted but never evaluated. All authenticated +requests have identical privileges. + +MCP tool configurations support an optional per-tool `auth-token` for +service-to-service authentication with remote MCP servers. These are +static, system-wide tokens — not per-user credentials. See +[mcp-tool-bearer-token.md](mcp-tool-bearer-token.md) for details. + +### User identity + +The `user` field is passed explicitly by the caller as a query parameter +(e.g. `?user=trustgraph`) or set by CLI tools. It flows through the +system in the core `Metadata` dataclass: + +```python +@dataclass +class Metadata: + id: str = "" + root: str = "" + user: str = "" + collection: str = "" +``` + +There is no user registration, login, user database, or session +management. + +### Data isolation + +The `user` + `collection` pair is used at the storage layer to partition +data: + +- **Cassandra**: queries filter by `user` and `collection` columns +- **Neo4j**: queries filter by `user` and `collection` properties +- **Qdrant**: vector search filters by `user` and `collection` metadata + +| Layer | Isolation mechanism | Enforced by | +|-------|-------------------|-------------| +| Gateway | Single shared token | `Authenticator` class | +| Message metadata | `user` + `collection` fields | Caller (honour system) | +| Cassandra | Column filters on `user`, `collection` | Query layer | +| Neo4j | Property filters on `user`, `collection` | Query layer | +| Qdrant | Metadata filters on `user`, `collection` | Query layer | +| Pub/sub topics | Per-flow topic namespacing | Flow service | + +The storage-layer isolation depends on all queries correctly filtering by +`user` and `collection`. There is no gateway-level enforcement preventing +a caller from querying another user's data by passing a different `user` +parameter. + +### Configuration and secrets + +| Setting | Source | Default | Purpose | +|---------|--------|---------|---------| +| `GATEWAY_SECRET` | Env var | Empty (auth disabled) | Gateway bearer token | +| `--api-token` | CLI arg | None | Gateway bearer token (overrides env) | +| `PULSAR_API_KEY` | Env var | None | Pub/sub broker auth | +| MCP `auth-token` | Config service | None | Per-tool MCP server auth | + +No secrets are encrypted at rest. The gateway token and MCP tokens are +stored and transmitted in plaintext (aside from any transport-layer +encryption such as TLS). + +### Capabilities that do not exist + +- Per-user authentication (JWT, OAuth, SAML, API keys per user) +- User accounts or user management +- Role-based access control (RBAC) +- Attribute-based access control (ABAC) +- Per-user or per-workspace API keys +- Token expiration or rotation +- Session management +- Per-user rate limiting +- Audit logging of user actions +- Permission checks preventing cross-user data access +- Multi-workspace credential isolation + +### Key files + +| File | Purpose | +|------|---------| +| `trustgraph-flow/trustgraph/gateway/auth.py` | Authenticator class | +| `trustgraph-flow/trustgraph/gateway/service.py` | Gateway init, token config | +| `trustgraph-flow/trustgraph/gateway/endpoint/*.py` | Per-endpoint auth checks | +| `trustgraph-base/trustgraph/schema/core/metadata.py` | `Metadata` dataclass with `user` field | + +## Technical Design + +### Design principles + +- **Auth at the edge.** The gateway is the single enforcement point. + Internal services trust the gateway and do not re-authenticate. + This avoids distributing credential validation across dozens of + microservices. + +- **Identity from credentials, not from callers.** The gateway derives + user identity from authentication credentials. Callers can no longer + self-declare their identity via query parameters. + +- **Workspace isolation by default.** Every authenticated user belongs to + a workspace. All data operations are scoped to that workspace. + Cross-workspace access is not possible through the API. + +- **Extensible API contract.** The API accepts an optional workspace + parameter on every request. This allows the same protocol to support + single-workspace deployments today and multi-workspace extensions in + the future without breaking changes. + +- **Simple roles, not fine-grained permissions.** A small number of + predefined roles controls what operations a user can perform. This is + sufficient for the current API surface and avoids the complexity of + per-resource permission management. + +### Authentication + +The gateway supports two credential types. Both are carried as a Bearer +token in the `Authorization` header for HTTP requests. The gateway +distinguishes them by format. + +For WebSocket connections, credentials are not passed in the URL or +headers. Instead, the client authenticates after connecting by sending +an auth message as the first frame: + +``` +Client: opens WebSocket to /api/v1/socket +Server: accepts connection (unauthenticated state) +Client: sends {"type": "auth", "token": "tg_abc123..."} +Server: validates token + success → {"type": "auth-ok", "workspace": "acme"} + failure → {"type": "auth-failed", "error": "invalid token"} +``` + +The server rejects all non-auth messages until authentication succeeds. +The socket remains open on auth failure, allowing the client to retry +with a different token without reconnecting. The client can also send +a new auth message at any time to re-authenticate — for example, to +refresh an expiring JWT or to switch workspace. The +resolved identity (user, workspace, roles) is updated on each +successful auth. + +#### API keys + +For programmatic access: CLI tools, scripts, and integrations. + +- Opaque tokens (e.g. `tg_a1b2c3d4e5f6...`). Not JWTs — short, + simple, easy to paste into CLI tools and headers. +- Each user has one or more API keys. +- Keys are stored hashed (SHA-256 with salt) in the IAM service. The + plaintext key is returned once at creation time and cannot be + retrieved afterwards. +- Keys can be revoked individually without affecting other users. +- Keys optionally have an expiry date. Expired keys are rejected. + +On each request, the gateway resolves an API key by: + +1. Hashing the token. +2. Checking a local cache (hash → user/workspace/roles). +3. On cache miss, calling the IAM service to resolve. +4. Caching the result with a short TTL (e.g. 60 seconds). + +Revoked keys stop working when the cache entry expires. No push +invalidation is needed. + +#### JWTs (login sessions) + +For interactive access via the UI or WebSocket connections. + +- A user logs in with username and password. The gateway forwards the + request to the IAM service, which validates the credentials and + returns a signed JWT. +- The JWT carries the user ID, workspace, and roles as claims. +- The gateway validates JWTs locally using the IAM service's public + signing key — no service call needed on subsequent requests. +- Token expiry is enforced by standard JWT validation at the time the + request (or WebSocket connection) is made. +- For long-lived WebSocket connections, the JWT is validated at connect + time only. The connection remains authenticated for its lifetime. + +The IAM service manages the signing key. The gateway fetches the public +key at startup (or on first JWT encounter) and caches it. + +#### Login endpoint + +``` +POST /api/v1/auth/login +{ + "username": "alice", + "password": "..." +} +→ { + "token": "eyJ...", + "expires": "2026-04-20T19:00:00Z" +} +``` + +The gateway forwards this to the IAM service, which validates +credentials and returns a signed JWT. The gateway returns the JWT to +the caller. + +#### IAM service delegation + +The gateway stays thin. Its authentication logic is: + +1. Extract Bearer token from header (or query param for WebSocket). +2. If the token has JWT format (dotted structure), validate the + signature locally and extract claims. +3. Otherwise, treat as an API key: hash it and check the local cache. + On cache miss, call the IAM service to resolve. +4. If neither succeeds, return 401. + +All user management, key management, credential validation, and token +signing logic lives in the IAM service. The gateway is a generic +enforcement point that can be replaced without changing the IAM +service. + +#### No legacy token support + +The existing `GATEWAY_SECRET` shared token is removed. All +authentication uses API keys or JWTs. On first start, the bootstrap +process creates a default workspace and admin user with an initial API +key. + +### User identity + +A user belongs to exactly one workspace. The design supports extending +this to multi-workspace access in the future (see +[Extension points](#extension-points)). + +A user record contains: + +| Field | Type | Description | +|-------|------|-------------| +| `id` | string | Unique user identifier (UUID) | +| `name` | string | Display name | +| `email` | string | Email address (optional) | +| `workspace` | string | Workspace the user belongs to | +| `roles` | list[string] | Assigned roles (e.g. `["reader"]`) | +| `enabled` | bool | Whether the user can authenticate | +| `created` | datetime | Account creation timestamp | + +The `workspace` field maps to the existing `user` field in `Metadata`. +This means the storage-layer isolation (Cassandra, Neo4j, Qdrant +filtering by `user` + `collection`) works without changes — the gateway +sets the `user` metadata field to the authenticated user's workspace. + +### Workspaces + +A workspace is an isolated data boundary. Users belong to a workspace, +and all data operations are scoped to it. Workspaces map to the existing +`user` field in `Metadata` and the corresponding Cassandra keyspace, +Qdrant collection prefix, and Neo4j property filters. + +| Field | Type | Description | +|-------|------|-------------| +| `id` | string | Unique workspace identifier | +| `name` | string | Display name | +| `enabled` | bool | Whether the workspace is active | +| `created` | datetime | Creation timestamp | + +All data operations are scoped to a workspace. The gateway determines +the effective workspace for each request as follows: + +1. If the request includes a `workspace` parameter, validate it against + the user's assigned workspace. + - If it matches, use it. + - If it does not match, return 403. (This could be extended to + check a workspace access grant list.) +2. If no `workspace` parameter is provided, use the user's assigned + workspace. + +The gateway sets the `user` field in `Metadata` to the effective +workspace ID, replacing the caller-supplied `?user=` query parameter. + +This design ensures forward compatibility. Clients that pass a +workspace parameter will work unchanged if multi-workspace support is +added later. Requests for an unassigned workspace get a clear 403 +rather than silent misbehaviour. + +### Roles and access control + +Three roles with fixed permissions: + +| Role | Data operations | Admin operations | System | +|------|----------------|-----------------|--------| +| `reader` | Query knowledge graph, embeddings, RAG | None | None | +| `writer` | All reader operations + load documents, manage collections | None | None | +| `admin` | All writer operations | Config, flows, collection management, user management | Metrics | + +Role checks happen at the gateway before dispatching to backend +services. Each endpoint declares the minimum role required: + +| Endpoint pattern | Minimum role | +|-----------------|--------------| +| `GET /api/v1/socket` (queries) | `reader` | +| `POST /api/v1/librarian` | `writer` | +| `POST /api/v1/flow/*/import/*` | `writer` | +| `POST /api/v1/config` | `admin` | +| `GET /api/v1/flow/*` | `admin` | +| `GET /api/metrics` | `admin` | + +Roles are hierarchical: `admin` implies `writer`, which implies +`reader`. + +### IAM service + +The IAM service is a new backend service that manages all identity and +access data. It is the authority for users, workspaces, API keys, and +credentials. The gateway delegates to it. + +#### Data model + +``` +iam_workspaces ( + id text PRIMARY KEY, + name text, + enabled boolean, + created timestamp +) + +iam_users ( + id text PRIMARY KEY, + workspace text, + name text, + email text, + password_hash text, + roles set, + enabled boolean, + created timestamp +) + +iam_api_keys ( + key_hash text PRIMARY KEY, + user_id text, + name text, + expires timestamp, + created timestamp +) +``` + +A secondary index on `iam_api_keys.user_id` supports listing a user's +keys. + +#### Responsibilities + +- User CRUD (create, list, update, disable) +- Workspace CRUD (create, list, update, disable) +- API key management (create, revoke, list) +- API key resolution (hash → user/workspace/roles) +- Credential validation (username/password → signed JWT) +- JWT signing key management (initialise, rotate) +- Bootstrap (create default workspace and admin user on first start) + +#### Communication + +The IAM service communicates via the standard request/response pub/sub +pattern, the same as the config service. The gateway calls it to +resolve API keys and to handle login requests. User management +operations (create user, revoke key, etc.) also go through the IAM +service. + +### Gateway changes + +The current `Authenticator` class is replaced with a thin authentication +middleware that delegates to the IAM service: + +For HTTP requests: + +1. Extract Bearer token from the `Authorization` header. +2. If the token has JWT format (dotted structure): + - Validate signature locally using the cached public key. + - Extract user ID, workspace, and roles from claims. +3. Otherwise, treat as an API key: + - Hash the token and check the local cache. + - On cache miss, call the IAM service to resolve. + - Cache the result (user/workspace/roles) with a short TTL. +4. If neither succeeds, return 401. +5. If the user or workspace is disabled, return 403. +6. Check the user's role against the endpoint's minimum role. If + insufficient, return 403. +7. Resolve the effective workspace: + - If the request includes a `workspace` parameter, validate it + against the user's assigned workspace. Return 403 on mismatch. + - If no `workspace` parameter, use the user's assigned workspace. +8. Set the `user` field in the request context to the effective + workspace ID. This propagates through `Metadata` to all downstream + services. + +For WebSocket connections: + +1. Accept the connection in an unauthenticated state. +2. Wait for an auth message (`{"type": "auth", "token": "..."}`). +3. Validate the token using the same logic as steps 2-7 above. +4. On success, attach the resolved identity to the connection and + send `{"type": "auth-ok", ...}`. +5. On failure, send `{"type": "auth-failed", ...}` but keep the + socket open. +6. Reject all non-auth messages until authentication succeeds. +7. Accept new auth messages at any time to re-authenticate. + +### CLI changes + +CLI tools authenticate with API keys: + +- `--api-key` argument on all CLI tools, replacing `--api-token`. +- `tg-create-workspace`, `tg-list-workspaces` for workspace management. +- `tg-create-user`, `tg-list-users`, `tg-disable-user` for user + management. +- `tg-create-api-key`, `tg-list-api-keys`, `tg-revoke-api-key` for + key management. +- `--workspace` argument on tools that operate on workspace-scoped + data. +- The API key is passed as a Bearer token in the same way as the + current shared token, so the transport protocol is unchanged. + +### Audit logging + +With user identity established, the gateway logs: + +- Timestamp, user ID, workspace, endpoint, HTTP method, response status. +- Audit logs are written to the standard logging output (structured + JSON). Integration with external log aggregation (Loki, ELK) is a + deployment concern, not an application concern. + +### Config service changes + +All configuration is workspace-scoped (see +[data-ownership-model.md](data-ownership-model.md)). The config service +needs to support this. + +#### Schema change + +The config table adds workspace as a key dimension: + +``` +config ( + workspace text, + class text, + key text, + value text, + PRIMARY KEY ((workspace, class), key) +) +``` + +#### Request format + +Config requests add a `workspace` field at the request level. The +existing `(type, key)` structure is unchanged within each workspace. + +**Get:** +```json +{ + "operation": "get", + "workspace": "workspace-a", + "keys": [{"type": "prompt", "key": "rag-prompt"}] +} +``` + +**Put:** +```json +{ + "operation": "put", + "workspace": "workspace-a", + "values": [{"type": "prompt", "key": "rag-prompt", "value": "..."}] +} +``` + +**List (all keys of a type within a workspace):** +```json +{ + "operation": "list", + "workspace": "workspace-a", + "type": "prompt" +} +``` + +**Delete:** +```json +{ + "operation": "delete", + "workspace": "workspace-a", + "keys": [{"type": "prompt", "key": "rag-prompt"}] +} +``` + +The workspace is set by: + +- **Gateway** — from the authenticated user's workspace for API-facing + requests. +- **Internal services** — explicitly, based on `Metadata.user` from + the message being processed, or `_system` for operational config. + +#### System config namespace + +Processor-level operational config (logging levels, connection strings, +resource limits) is not workspace-specific. This stays in a reserved +`_system` workspace that is not associated with any user workspace. +Services read system config at startup without needing a workspace +context. + +#### Config change notifications + +The config notify mechanism pushes change notifications via pub/sub +when config is updated. A single update may affect multiple workspaces +and multiple config types. The notification message carries a dict of +changes keyed by config type, with each value being the list of +affected workspaces: + +```json +{ + "version": 42, + "changes": { + "prompt": ["workspace-a", "workspace-b"], + "schema": ["workspace-a"] + } +} +``` + +System config changes use the reserved `_system` workspace: + +```json +{ + "version": 43, + "changes": { + "logging": ["_system"] + } +} +``` + +This structure is keyed by type because handlers register by type. A +handler registered for `prompt` looks up `"prompt"` directly and gets +the list of affected workspaces — no iteration over unrelated types. + +#### Config change handlers + +The current `on_config` hook mechanism needs two modes to support shared +processing services: + +- **Workspace-scoped handlers** — notify when a config type changes in a + specific workspace. The handler looks up its registered type in the + changes dict and checks if its workspace is in the list. Used by the + gateway and by services that serve a single workspace. + +- **Global handlers** — notify when a config type changes in any + workspace. The handler looks up its registered type in the changes + dict and gets the full list of affected workspaces. Used by shared + processing services (prompt-rag, agent manager, etc.) that serve all + workspaces. Each workspace in the list tells the handler which cache + entry to update rather than reloading everything. + +#### Per-workspace config caching + +Shared services that handle messages from multiple workspaces maintain a +per-workspace config cache. When a message arrives, the service looks up +the config for the workspace identified in `Metadata.user`. If the +workspace is not yet cached, the service fetches its config on demand. +Config change notifications update the relevant cache entry. + +### Flow and queue isolation + +Flows are workspace-owned. When two workspaces start flows with the same +name and blueprint, their queues must be separate to prevent data +mixing. + +Flow blueprint templates currently use `{id}` (flow instance ID) and +`{class}` (blueprint name) as template variables in queue names. A new +`{workspace}` variable is added so queue names include the workspace: + +**Current queue names (no workspace isolation):** +``` +flow:tg:document-load:{id} → flow:tg:document-load:default +request:tg:embeddings:{class} → request:tg:embeddings:everything +``` + +**With workspace isolation:** +``` +flow:tg:{workspace}:document-load:{id} → flow:tg:ws-a:document-load:default +request:tg:{workspace}:embeddings:{class} → request:tg:ws-a:embeddings:everything +``` + +The flow service substitutes `{workspace}` from the authenticated +workspace when starting a flow, the same way it substitutes `{id}` and +`{class}` today. + +Processing services are shared infrastructure — they consume from +workspace-specific queues but are not themselves workspace-aware. The +workspace is carried in `Metadata.user` on every message, so services +know which workspace's data they are processing. + +Blueprint templates need updating to include `{workspace}` in all queue +name patterns. For migration, the flow service can inject the workspace +into queue names automatically if the template does not include +`{workspace}`, defaulting to the legacy behaviour for existing +blueprints. + +See [flow-class-definition.md](flow-class-definition.md) for the full +blueprint template specification. + +### What changes and what doesn't + +**Changes:** + +| Component | Change | +|-----------|--------| +| `gateway/auth.py` | Replace `Authenticator` with new auth middleware | +| `gateway/service.py` | Initialise IAM client, configure JWT validation | +| `gateway/endpoint/*.py` | Add role requirement per endpoint | +| Metadata propagation | Gateway sets `user` from workspace, ignores query param | +| Config service | Add workspace dimension to config schema | +| Config table | `PRIMARY KEY ((workspace, class), key)` | +| Config request/response schema | Add `workspace` field | +| Config notify messages | Include workspace ID in change notifications | +| `on_config` handlers | Support workspace-scoped and global modes | +| Shared services | Per-workspace config caching | +| Flow blueprints | Add `{workspace}` template variable to queue names | +| Flow service | Substitute `{workspace}` when starting flows | +| CLI tools | New user management commands, `--api-key` argument | +| Cassandra schema | New `iam_workspaces`, `iam_users`, `iam_api_keys` tables | + +**Does not change:** + +| Component | Reason | +|-----------|--------| +| Internal service-to-service pub/sub | Services trust the gateway | +| `Metadata` dataclass | `user` field continues to carry workspace identity | +| Storage-layer isolation | Same `user` + `collection` filtering | +| Message serialisation | No schema changes | + +### Migration + +This is a breaking change. Existing deployments must be reconfigured: + +1. `GATEWAY_SECRET` is removed. Authentication requires API keys or + JWT login tokens. +2. The `?user=` query parameter is removed. Workspace identity comes + from authentication. +3. On first start, the IAM service bootstraps a default workspace and + admin user. The initial API key is output to the service log. +4. Operators create additional workspaces and users via CLI tools. +5. Flow blueprints must be updated to include `{workspace}` in queue + name patterns. +6. Config data must be migrated to include the workspace dimension. + +## Extension points + +The design includes deliberate extension points for future capabilities. +These are not implemented but the architecture does not preclude them: + +- **Multi-workspace access.** Users could be granted access to + additional workspaces beyond their primary assignment. The workspace + validation step checks a grant list instead of a single assignment. +- **Rules-based access control.** A separate access control service + could evaluate fine-grained policies (per-collection permissions, + operation-level restrictions, time-based access). The gateway + delegates authorisation decisions to this service. +- **External identity provider integration.** SAML, LDAP, and OIDC + flows (group mapping, claims-based role assignment) could be added + to the IAM service. +- **Cross-workspace administration.** A `superadmin` role for platform + operators who manage multiple workspaces. +- **Delegated workspace provisioning.** APIs for programmatic workspace + creation and user onboarding. + +These extensions are additive — they extend the validation logic +without changing the request/response protocol. The gateway can be +replaced with an alternative implementation that supports these +capabilities while the IAM service and backend services remain +unchanged. + +## Implementation plan + +Workspace support is a prerequisite for auth — users are assigned to +workspaces, config is workspace-scoped, and flows use workspace in +queue names. Implementing workspaces first allows the structural changes +to be tested end-to-end without auth complicating debugging. + +### Phase 1: Workspace support (no auth) + +All workspace-scoped data and processing changes. The system works with +workspaces but no authentication — callers pass workspace as a +parameter, honour system. This allows full end-to-end testing: multiple +workspaces with separate flows, config, queues, and data. + +#### Config service + +- Update config client API to accept a workspace parameter on all + requests +- Update config storage schema to add workspace as a key dimension +- Update config notification API to report changes as a dict of + type → workspace list +- Update the processor base class to understand workspaces in config + notifications (workspace-scoped and global handler modes) +- Update all processors to implement workspace-aware config handling + (per-workspace config caching, on-demand fetch) + +#### Flow and queue isolation + +- Update flow blueprints to include `{workspace}` in all queue name + patterns +- Update the flow service to substitute `{workspace}` when starting + flows +- Update all built-in blueprints to include `{workspace}` + +#### CLI tools (workspace support) + +- Add `--workspace` argument to CLI tools that operate on + workspace-scoped data +- Add `tg-create-workspace`, `tg-list-workspaces` commands + +### Phase 2: Authentication and access control + +With workspaces working, add the IAM service and lock down the gateway. + +#### IAM service + +A new service handling identity and access management on behalf of the +API gateway: + +- Add workspace table support (CRUD, enable/disable) +- Add user table support (CRUD, enable/disable, workspace assignment) +- Add roles support (role assignment, role validation) +- Add API key support (create, revoke, list, hash storage) +- Add ability to initialise a JWT signing key for token grants +- Add token grant endpoint: user/password login returns a signed JWT +- Add bootstrap/initialisation mechanism: ability to set the signing + key and create the initial workspace + admin user on first start + +#### API gateway integration + +- Add IAM middleware to the API gateway replacing the current + `Authenticator` +- Add local JWT validation (public key from IAM service) +- Add API key resolution with local cache (hash → user/workspace/roles, + cache miss calls IAM service, short TTL) +- Add login endpoint forwarding to IAM service +- Add workspace resolution: validate requested workspace against user + assignment +- Add role-based endpoint access checks +- Add user management API endpoints (forwarded to IAM service) +- Add audit logging (user ID, workspace, endpoint, method, status) +- WebSocket auth via first-message protocol (auth message after + connect, socket stays open on failure, re-auth supported) + +#### CLI tools (auth support) + +- Add `tg-create-user`, `tg-list-users`, `tg-disable-user` commands +- Add `tg-create-api-key`, `tg-list-api-keys`, `tg-revoke-api-key` + commands +- Replace `--api-token` with `--api-key` on existing CLI tools + +#### Bootstrap and cutover + +- Create default workspace and admin user on first start if IAM tables + are empty +- Remove `GATEWAY_SECRET` and `?user=` query parameter support + +## Design Decisions + +### IAM data store + +IAM data is stored in dedicated Cassandra tables owned by the IAM +service, not in the config service. Reasons: + +- **Security isolation.** The config service has a broad, generic + protocol. An access control failure on the config service could + expose credentials. A dedicated IAM service with a purpose-built + protocol limits the attack surface and makes security auditing + clearer. +- **Data model fit.** IAM needs indexed lookups (API key hash → user, + list keys by user). The config service's `(workspace, type, key) → + value` model stores opaque JSON strings with no secondary indexes. +- **Scope.** IAM data is global (workspaces, users, keys). Config is + workspace-scoped. Mixing global and workspace-scoped data in the + same store adds complexity. +- **Audit.** IAM operations (key creation, revocation, login attempts) + are security events that should be logged separately from general + config changes. + +## Deferred to future design + +- **OIDC integration.** External identity provider support (SAML, LDAP, + OIDC) is left for future implementation. The extension points section + describes where this fits architecturally. +- **API key scoping.** API keys could be scoped to specific collections + within a workspace rather than granting workspace-wide access. To be + designed when the need arises. +- **tg-init-trustgraph** only initialises a single workspace. + +## References + +- [Data Ownership and Information Separation](data-ownership-model.md) +- [MCP Tool Bearer Token Specification](mcp-tool-bearer-token.md) +- [Multi-Tenant Support Specification](multi-tenant-support.md) +- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md) diff --git a/tests/integration/test_agent_structured_query_integration.py b/tests/integration/test_agent_structured_query_integration.py index 2442bf10..8ce2d467 100644 --- a/tests/integration/test_agent_structured_query_integration.py +++ b/tests/integration/test_agent_structured_query_integration.py @@ -58,7 +58,7 @@ class TestAgentStructuredQueryIntegration: async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config): """Test basic agent integration with structured query tool""" # Arrange - Load tool configuration - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") # Create agent request request = AgentRequest( @@ -119,6 +119,7 @@ Args: { # Mock flow parameter in agent_processor.on_request flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -146,7 +147,7 @@ Args: { async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config): """Test agent handling of structured query errors""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") request = AgentRequest( question="Find data from a table that doesn't exist using structured query.", @@ -199,6 +200,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -221,7 +223,7 @@ Args: { async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config): """Test agent using structured query in multi-step reasoning""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") request = AgentRequest( question="First find all customers from California, then tell me how many orders they have made.", @@ -279,6 +281,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -313,7 +316,7 @@ Args: { } } - await agent_processor.on_tools_config(tool_config_with_collection, "v1") + await agent_processor.on_tools_config("default", tool_config_with_collection, "v1") request = AgentRequest( question="Query the sales data for recent transactions.", @@ -371,6 +374,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -394,10 +398,10 @@ Args: { async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config): """Test that structured query tool arguments are properly validated""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") # Check that the tool was registered with correct arguments - tools = agent_processor.agent.tools + tools = agent_processor.agents["default"].tools assert "structured-query" in tools structured_tool = tools["structured-query"] @@ -414,7 +418,7 @@ Args: { async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config): """Test that structured query results are properly formatted for agent consumption""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") request = AgentRequest( question="Get customer information and format it nicely.", @@ -482,6 +486,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) diff --git a/tests/integration/test_nlp_query_integration.py b/tests/integration/test_nlp_query_integration.py index 16c4543e..08bf1e77 100644 --- a/tests/integration/test_nlp_query_integration.py +++ b/tests/integration/test_nlp_query_integration.py @@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration: ) # Set up schemas - proc.schemas = sample_schemas + proc.schemas = {"default": dict(sample_schemas)} # Mock the client method proc.client = MagicMock() @@ -94,6 +94,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration: } # Act - Update configuration - await integration_processor.on_schema_config(new_schema_config, "v2") + await integration_processor.on_schema_config("default", new_schema_config, "v2") # Arrange - Test query using new schema request = QuestionToStructuredQueryRequest( @@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration: await integration_processor.on_message(msg, consumer, flow) # Assert - assert "inventory" in integration_processor.schemas + assert "inventory" in integration_processor.schemas["default"] response_call = flow_response.send.call_args response = response_call[0][0] assert response.detected_schemas == ["inventory"] @@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration: graphql_generation_template="custom-graphql-generator" ) - custom_processor.schemas = sample_schemas + custom_processor.schemas = {"default": dict(sample_schemas)} custom_processor.client = MagicMock() request = QuestionToStructuredQueryRequest( @@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration: ] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)] ) - integration_processor.schemas.update(large_schema_set) + integration_processor.schemas["default"].update(large_schema_set) request = QuestionToStructuredQueryRequest( question="Show me data from table_05 and table_12", @@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration: msg.properties.return_value = {"id": f"concurrent-test-{i}"} flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response diff --git a/tests/integration/test_object_extraction_integration.py b/tests/integration/test_object_extraction_integration.py index 22ba9a3f..32e74436 100644 --- a/tests/integration/test_object_extraction_integration.py +++ b/tests/integration/test_object_extraction_integration.py @@ -185,6 +185,7 @@ class TestObjectExtractionServiceIntegration: return AsyncMock() context.side_effect = context_router + context.workspace = "default" return context @pytest.mark.asyncio @@ -197,20 +198,21 @@ class TestObjectExtractionServiceIntegration: processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Act - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Assert - assert len(processor.schemas) == 2 - assert "customer_records" in processor.schemas - assert "product_catalog" in processor.schemas - + ws_schemas = processor.schemas["default"] + assert len(ws_schemas) == 2 + assert "customer_records" in ws_schemas + assert "product_catalog" in ws_schemas + # Verify customer schema - customer_schema = processor.schemas["customer_records"] + customer_schema = ws_schemas["customer_records"] assert customer_schema.name == "customer_records" assert len(customer_schema.fields) == 4 - + # Verify product schema - product_schema = processor.schemas["product_catalog"] + product_schema = ws_schemas["product_catalog"] assert product_schema.name == "product_catalog" assert len(product_schema.fields) == 4 @@ -237,7 +239,7 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create realistic customer data chunk metadata = Metadata( @@ -304,7 +306,7 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create realistic product data chunk metadata = Metadata( @@ -368,7 +370,7 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create multiple test chunks chunks_data = [ @@ -431,19 +433,21 @@ class TestObjectExtractionServiceIntegration: "customer_records": integration_config["schema"]["customer_records"] } } - await processor.on_schema_config(initial_config, version=1) - - assert len(processor.schemas) == 1 - assert "customer_records" in processor.schemas - assert "product_catalog" not in processor.schemas - + await processor.on_schema_config("default", initial_config, version=1) + + ws_schemas = processor.schemas["default"] + assert len(ws_schemas) == 1 + assert "customer_records" in ws_schemas + assert "product_catalog" not in ws_schemas + # Act - Reload with full configuration - await processor.on_schema_config(integration_config, version=2) - + await processor.on_schema_config("default", integration_config, version=2) + # Assert - assert len(processor.schemas) == 2 - assert "customer_records" in processor.schemas - assert "product_catalog" in processor.schemas + ws_schemas = processor.schemas["default"] + assert len(ws_schemas) == 2 + assert "customer_records" in ws_schemas + assert "product_catalog" in ws_schemas @pytest.mark.asyncio async def test_error_resilience_integration(self, integration_config): @@ -474,10 +478,11 @@ class TestObjectExtractionServiceIntegration: return AsyncMock() failing_flow.side_effect = failing_context_router + failing_flow.workspace = "default" processor.flow = failing_flow # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create test chunk metadata = Metadata(id="error-test", user="test", collection="test") @@ -510,7 +515,7 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create chunk with rich metadata original_metadata = Metadata( diff --git a/tests/integration/test_prompt_streaming_integration.py b/tests/integration/test_prompt_streaming_integration.py index a1414e2d..84a3cdec 100644 --- a/tests/integration/test_prompt_streaming_integration.py +++ b/tests/integration/test_prompt_streaming_integration.py @@ -87,6 +87,7 @@ class TestPromptStreaming: return AsyncMock() context.side_effect = context_router + context.workspace = "default" return context @pytest.fixture @@ -109,7 +110,7 @@ class TestPromptStreaming: def prompt_processor_streaming(self, mock_prompt_manager): """Create Prompt processor with streaming support""" processor = MagicMock() - processor.manager = mock_prompt_manager + processor.managers = {"default": mock_prompt_manager} processor.config_key = "prompt" # Bind the actual on_request method @@ -248,6 +249,7 @@ class TestPromptStreaming: return AsyncMock() context.side_effect = context_router + context.workspace = "default" request = PromptRequest( id="test_prompt", @@ -341,6 +343,7 @@ class TestPromptStreaming: return AsyncMock() context.side_effect = context_router + context.workspace = "default" request = PromptRequest( id="test_prompt", diff --git a/tests/integration/test_rows_cassandra_integration.py b/tests/integration/test_rows_cassandra_integration.py index a2b8ae08..de6bbeeb 100644 --- a/tests/integration/test_rows_cassandra_integration.py +++ b/tests/integration/test_rows_cassandra_integration.py @@ -14,6 +14,17 @@ from trustgraph.storage.rows.cassandra.write import Processor from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field + + +class _MockFlowDefault: + """Mock Flow with default workspace for testing.""" + workspace = "default" + name = "default" + id = "test-processor" + + +mock_flow_default = _MockFlowDefault() + @pytest.mark.integration class TestRowsCassandraIntegration: """Integration tests for Cassandra row storage with unified table""" @@ -125,8 +136,8 @@ class TestRowsCassandraIntegration: } } - await processor.on_schema_config(config, version=1) - assert "customer_records" in processor.schemas + await processor.on_schema_config("default", config, version=1) + assert "customer_records" in processor.schemas["default"] # Step 2: Process an ExtractedObject test_obj = ExtractedObject( @@ -149,7 +160,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify Cassandra interactions assert mock_cluster.connect.called @@ -209,8 +220,8 @@ class TestRowsCassandraIntegration: } } - await processor.on_schema_config(config, version=1) - assert len(processor.schemas) == 2 + await processor.on_schema_config("default", config, version=1) + assert len(processor.schemas["default"]) == 2 # Process objects for different schemas product_obj = ExtractedObject( @@ -233,7 +244,7 @@ class TestRowsCassandraIntegration: for obj in [product_obj, order_obj]: msg = MagicMock() msg.value.return_value = obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # All data goes into the same unified rows table table_calls = [call for call in mock_session.execute.call_args_list @@ -256,15 +267,17 @@ class TestRowsCassandraIntegration: with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): # Schema with multiple indexed fields - processor.schemas["indexed_data"] = RowSchema( - name="indexed_data", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="category", type="string", size=50, indexed=True), - Field(name="status", type="string", size=50, indexed=True), - Field(name="description", type="string", size=200) # Not indexed - ] - ) + processor.schemas["default"] = { + "indexed_data": RowSchema( + name="indexed_data", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="category", type="string", size=50, indexed=True), + Field(name="status", type="string", size=50, indexed=True), + Field(name="description", type="string", size=200) # Not indexed + ] + ) + } test_obj = ExtractedObject( metadata=Metadata(id="t1", user="test", collection="test"), @@ -282,7 +295,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should have 3 data inserts (one per indexed field: id, category, status) rows_insert_calls = [call for call in mock_session.execute.call_args_list @@ -342,7 +355,7 @@ class TestRowsCassandraIntegration: } } - await processor.on_schema_config(config, version=1) + await processor.on_schema_config("default", config, version=1) # Process batch object with multiple values batch_obj = ExtractedObject( @@ -376,7 +389,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = batch_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify unified table creation table_calls = [call for call in mock_session.execute.call_args_list @@ -396,10 +409,12 @@ class TestRowsCassandraIntegration: processor, mock_cluster, mock_session = processor_with_mocks with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["empty_test"] = RowSchema( - name="empty_test", - fields=[Field(name="id", type="string", size=50, primary=True)] - ) + processor.schemas["default"] = { + "empty_test": RowSchema( + name="empty_test", + fields=[Field(name="id", type="string", size=50, primary=True)] + ) + } # Process empty batch object empty_obj = ExtractedObject( @@ -413,7 +428,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = empty_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should not create any data insert statements for empty batch # (partition registration may still happen) @@ -428,14 +443,16 @@ class TestRowsCassandraIntegration: processor, mock_cluster, mock_session = processor_with_mocks with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["map_test"] = RowSchema( - name="map_test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="name", type="string", size=100), - Field(name="count", type="integer", size=0) - ] - ) + processor.schemas["default"] = { + "map_test": RowSchema( + name="map_test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="name", type="string", size=100), + Field(name="count", type="integer", size=0) + ] + ) + } test_obj = ExtractedObject( metadata=Metadata(id="t1", user="test", collection="test"), @@ -448,7 +465,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify insert uses map for data rows_insert_calls = [call for call in mock_session.execute.call_args_list @@ -473,13 +490,15 @@ class TestRowsCassandraIntegration: processor, mock_cluster, mock_session = processor_with_mocks with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["partition_test"] = RowSchema( - name="partition_test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="category", type="string", size=50, indexed=True) - ] - ) + processor.schemas["default"] = { + "partition_test": RowSchema( + name="partition_test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="category", type="string", size=50, indexed=True) + ] + ) + } test_obj = ExtractedObject( metadata=Metadata(id="t1", user="test", collection="my_collection"), @@ -492,7 +511,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify partition registration partition_inserts = [call for call in mock_session.execute.call_args_list diff --git a/tests/integration/test_rows_graphql_query_integration.py b/tests/integration/test_rows_graphql_query_integration.py index a717901b..29b4464d 100644 --- a/tests/integration/test_rows_graphql_query_integration.py +++ b/tests/integration/test_rows_graphql_query_integration.py @@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_schema_configuration_and_generation(self, processor, sample_schema_config): """Test schema configuration loading and GraphQL schema generation""" # Load schema configuration - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) # Verify schemas were loaded assert len(processor.schemas) == 2 @@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config): """Test Cassandra connection and dynamic table creation""" # Load schema configuration - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) # Connect to Cassandra processor.connect_cassandra() @@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config): """Test inserting data and querying via GraphQL""" # Load schema and connect - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() # Setup test data @@ -292,7 +292,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_graphql_query_with_filters(self, processor, sample_schema_config): """Test GraphQL queries with filtering on indexed fields""" # Setup (reuse previous setup) - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() keyspace = "test_user" @@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_graphql_error_handling(self, processor, sample_schema_config): """Test GraphQL error handling for invalid queries""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) # Test invalid field query invalid_query = ''' @@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_message_processing_integration(self, processor, sample_schema_config): """Test full message processing workflow""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() # Create mock message @@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_concurrent_queries(self, processor, sample_schema_config): """Test handling multiple concurrent GraphQL queries""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() # Create multiple query tasks @@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration: } } - await processor.on_schema_config(initial_config, version=1) + await processor.on_schema_config("default", initial_config, version=1) assert len(processor.schemas) == 1 assert "simple" in processor.schemas @@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration: } } - await processor.on_schema_config(updated_config, version=2) + await processor.on_schema_config("default", updated_config, version=2) # Verify updated schemas assert len(processor.schemas) == 2 @@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_large_result_set_handling(self, processor, sample_schema_config): """Test handling of large query result sets""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() keyspace = "large_test_user" @@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance: } } - await processor.on_schema_config(schema_config, version=1) + await processor.on_schema_config("default", schema_config, version=1) # Measure query execution time start_time = time.time() diff --git a/tests/unit/test_agent/test_agent_service_non_streaming.py b/tests/unit/test_agent/test_agent_service_non_streaming.py index bb58e5ee..5240b62c 100644 --- a/tests/unit/test_agent/test_agent_service_non_streaming.py +++ b/tests/unit/test_agent/test_agent_service_non_streaming.py @@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming: # Setup mock agent manager mock_agent_instance = AsyncMock() mock_agent_manager_class.return_value = mock_agent_instance + mock_agent_instance.tools = {} + mock_agent_instance.additional_context = "" + processor.agents["default"] = mock_agent_instance # Mock react to call think and observe callbacks async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None): @@ -58,6 +61,7 @@ class TestAgentServiceNonStreaming: # Setup flow mock consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" mock_producer = AsyncMock() @@ -129,6 +133,9 @@ class TestAgentServiceNonStreaming: # Setup mock agent manager mock_agent_instance = AsyncMock() mock_agent_manager_class.return_value = mock_agent_instance + mock_agent_instance.tools = {} + mock_agent_instance.additional_context = "" + processor.agents["default"] = mock_agent_instance # Mock react to return Final directly async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None): @@ -148,6 +155,7 @@ class TestAgentServiceNonStreaming: # Setup flow mock consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" mock_producer = AsyncMock() diff --git a/tests/unit/test_agent/test_tool_service_lifecycle.py b/tests/unit/test_agent/test_tool_service_lifecycle.py index 65cdb542..382244fb 100644 --- a/tests/unit/test_agent/test_tool_service_lifecycle.py +++ b/tests/unit/test_agent/test_tool_service_lifecycle.py @@ -241,7 +241,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def mock_invoke(name, params): + async def mock_invoke(workspace, name, params): return "tool result" svc.invoke_tool = mock_invoke @@ -260,6 +260,7 @@ class TestToolServiceOnRequest: flow_callable.producer = {"response": mock_response_pub} flow_callable.name = "test-flow" + flow_callable.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}') @@ -280,7 +281,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def mock_invoke(name, params): + async def mock_invoke(workspace, name, params): return {"data": [1, 2, 3]} svc.invoke_tool = mock_invoke @@ -298,6 +299,7 @@ class TestToolServiceOnRequest: flow_callable.producer = {"response": mock_response_pub} flow_callable.name = "test-flow" + flow_callable.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") @@ -317,7 +319,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def failing_invoke(name, params): + async def failing_invoke(workspace, name, params): raise RuntimeError("tool broke") svc.invoke_tool = failing_invoke @@ -330,6 +332,7 @@ class TestToolServiceOnRequest: flow_callable.producer = {"response": mock_response_pub} flow_callable.name = "test-flow" + flow_callable.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") @@ -350,7 +353,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def rate_limited(name, params): + async def rate_limited(workspace, name, params): raise TooManyRequests("slow down") svc.invoke_tool = rate_limited @@ -362,6 +365,7 @@ class TestToolServiceOnRequest: flow = MagicMock() flow.producer = {"response": AsyncMock()} flow.name = "test-flow" + flow.workspace = "default" with pytest.raises(TooManyRequests): await svc.on_request(msg, MagicMock(), flow) @@ -376,7 +380,8 @@ class TestToolServiceOnRequest: received = {} - async def capture_invoke(name, params): + async def capture_invoke(workspace, name, params): + received["workspace"] = workspace received["name"] = name received["params"] = params return "ok" @@ -390,6 +395,7 @@ class TestToolServiceOnRequest: flow = lambda name: mock_pub flow.producer = {"response": mock_pub} flow.name = "f" + flow.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest( diff --git a/tests/unit/test_base/test_async_processor_config.py b/tests/unit/test_base/test_async_processor_config.py index f1a83fef..3dffd775 100644 --- a/tests/unit/test_base/test_async_processor_config.py +++ b/tests/unit/test_base/test_async_processor_config.py @@ -1,17 +1,14 @@ """ Tests for AsyncProcessor config notify pattern: - register_config_handler with types filtering -- on_config_notify version comparison and type matching -- fetch_config with short-lived client -- fetch_and_apply_config retry logic +- on_config_notify version comparison, type/workspace matching +- fetch_and_apply_config retry logic over per-workspace fetches """ import pytest from unittest.mock import AsyncMock, MagicMock, patch, Mock -from trustgraph.schema import Term, IRI, LITERAL -# Patch heavy dependencies before importing AsyncProcessor @pytest.fixture def processor(): """Create an AsyncProcessor with mocked dependencies.""" @@ -68,6 +65,13 @@ class TestRegisterConfigHandler: assert len(processor.config_handlers) == 2 +def _notify_msg(version, changes): + """Build a Mock config-notify message with given version and changes dict.""" + msg = Mock() + msg.value.return_value = Mock(version=version, changes=changes) + return msg + + class TestOnConfigNotify: @pytest.mark.asyncio @@ -77,9 +81,7 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - msg = Mock() - msg.value.return_value = Mock(version=3, types=["prompt"]) - + msg = _notify_msg(3, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) handler.assert_not_called() @@ -91,9 +93,7 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - msg = Mock() - msg.value.return_value = Mock(version=5, types=["prompt"]) - + msg = _notify_msg(5, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) handler.assert_not_called() @@ -105,9 +105,7 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - msg = Mock() - msg.value.return_value = Mock(version=2, types=["schema"]) - + msg = _notify_msg(2, {"schema": ["default"]}) await processor.on_config_notify(msg, None, None) handler.assert_not_called() @@ -121,40 +119,36 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - # Mock fetch_config - mock_config = {"prompt": {"key": "value"}} + mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - return_value=(mock_config, 2) + return_value={"key": "value"}, ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["prompt"]) - + msg = _notify_msg(2, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) - handler.assert_called_once_with(mock_config, 2) + handler.assert_called_once_with( + "default", {"prompt": {"key": "value"}}, 2 + ) assert processor.config_version == 2 @pytest.mark.asyncio - async def test_handler_without_types_always_called(self, processor): + async def test_handler_without_types_ignored_on_notify(self, processor): + """Handlers registered without types never fire on notifications.""" processor.config_version = 1 handler = AsyncMock() - processor.register_config_handler(handler) # No types = all + processor.register_config_handler(handler) # No types - mock_config = {"anything": {}} - with patch.object( - processor, 'fetch_config', - new_callable=AsyncMock, - return_value=(mock_config, 2) - ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["whatever"]) + msg = _notify_msg(2, {"whatever": ["default"]}) + await processor.on_config_notify(msg, None, None) - await processor.on_config_notify(msg, None, None) - - handler.assert_called_once_with(mock_config, 2) + handler.assert_not_called() + # Version still advances past the notify + assert processor.config_version == 2 @pytest.mark.asyncio async def test_mixed_handlers_type_filtering(self, processor): @@ -168,156 +162,149 @@ class TestOnConfigNotify: processor.register_config_handler(schema_handler, types=["schema"]) processor.register_config_handler(all_handler) - mock_config = {"prompt": {}} + mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - return_value=(mock_config, 2) + return_value={}, ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["prompt"]) - + msg = _notify_msg(2, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) - prompt_handler.assert_called_once() + prompt_handler.assert_called_once_with( + "default", {"prompt": {}}, 2 + ) schema_handler.assert_not_called() - all_handler.assert_called_once() + all_handler.assert_not_called() @pytest.mark.asyncio - async def test_empty_types_invokes_all(self, processor): - """Empty types list (startup signal) should invoke all handlers.""" + async def test_multi_workspace_notify_invokes_handler_per_ws( + self, processor + ): + """Notify affecting multiple workspaces invokes handler once per workspace.""" processor.config_version = 1 - h1 = AsyncMock() - h2 = AsyncMock() - processor.register_config_handler(h1, types=["prompt"]) - processor.register_config_handler(h2, types=["schema"]) + handler = AsyncMock() + processor.register_config_handler(handler, types=["prompt"]) - mock_config = {} + mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - return_value=(mock_config, 2) + return_value={}, ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=[]) - + msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]}) await processor.on_config_notify(msg, None, None) - h1.assert_called_once() - h2.assert_called_once() + assert handler.call_count == 2 + called_workspaces = {c.args[0] for c in handler.call_args_list} + assert called_workspaces == {"ws1", "ws2"} @pytest.mark.asyncio async def test_fetch_failure_handled(self, processor): processor.config_version = 1 handler = AsyncMock() - processor.register_config_handler(handler) + processor.register_config_handler(handler, types=["prompt"]) + mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - side_effect=RuntimeError("Connection failed") + side_effect=RuntimeError("Connection failed"), ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["prompt"]) - + msg = _notify_msg(2, {"prompt": ["default"]}) # Should not raise await processor.on_config_notify(msg, None, None) handler.assert_not_called() -class TestFetchConfig: - - @pytest.mark.asyncio - async def test_fetch_returns_config_and_version(self, processor): - mock_resp = Mock() - mock_resp.error = None - mock_resp.config = {"prompt": {"key": "val"}} - mock_resp.version = 42 - - mock_client = AsyncMock() - mock_client.request.return_value = mock_resp - - with patch.object( - processor, '_create_config_client', return_value=mock_client - ): - config, version = await processor.fetch_config() - - assert config == {"prompt": {"key": "val"}} - assert version == 42 - mock_client.stop.assert_called_once() - - @pytest.mark.asyncio - async def test_fetch_raises_on_error_response(self, processor): - mock_resp = Mock() - mock_resp.error = Mock(message="not found") - mock_resp.config = {} - mock_resp.version = 0 - - mock_client = AsyncMock() - mock_client.request.return_value = mock_resp - - with patch.object( - processor, '_create_config_client', return_value=mock_client - ): - with pytest.raises(RuntimeError, match="Config error"): - await processor.fetch_config() - - mock_client.stop.assert_called_once() - - @pytest.mark.asyncio - async def test_fetch_stops_client_on_exception(self, processor): - mock_client = AsyncMock() - mock_client.request.side_effect = TimeoutError("timeout") - - with patch.object( - processor, '_create_config_client', return_value=mock_client - ): - with pytest.raises(TimeoutError): - await processor.fetch_config() - - mock_client.stop.assert_called_once() - - class TestFetchAndApplyConfig: @pytest.mark.asyncio - async def test_applies_config_to_all_handlers(self, processor): - h1 = AsyncMock() - h2 = AsyncMock() - processor.register_config_handler(h1, types=["prompt"]) - processor.register_config_handler(h2, types=["schema"]) + async def test_applies_config_per_workspace(self, processor): + """Startup fetch invokes handler once per workspace affected.""" + h = AsyncMock() + processor.register_config_handler(h, types=["prompt"]) + + mock_client = AsyncMock() + + async def fake_fetch_all(client, config_type): + return { + "ws1": {"k": "v1"}, + "ws2": {"k": "v2"}, + }, 10 - mock_config = {"prompt": {}, "schema": {}} with patch.object( - processor, 'fetch_config', - new_callable=AsyncMock, - return_value=(mock_config, 10) + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_all_workspaces', + new=fake_fetch_all, ): await processor.fetch_and_apply_config() - # On startup, all handlers are invoked regardless of type - h1.assert_called_once_with(mock_config, 10) - h2.assert_called_once_with(mock_config, 10) + assert h.call_count == 2 + call_map = {c.args[0]: c.args[1] for c in h.call_args_list} + assert call_map["ws1"] == {"prompt": {"k": "v1"}} + assert call_map["ws2"] == {"prompt": {"k": "v2"}} assert processor.config_version == 10 @pytest.mark.asyncio - async def test_retries_on_failure(self, processor): - call_count = 0 - mock_config = {"prompt": {}} + async def test_handler_without_types_skipped_at_startup(self, processor): + """Handlers registered without types fetch nothing at startup.""" + typed = AsyncMock() + untyped = AsyncMock() + processor.register_config_handler(typed, types=["prompt"]) + processor.register_config_handler(untyped) - async def mock_fetch(): + mock_client = AsyncMock() + + async def fake_fetch_all(client, config_type): + return {"default": {}}, 1 + + with patch.object( + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_all_workspaces', + new=fake_fetch_all, + ): + await processor.fetch_and_apply_config() + + typed.assert_called_once() + untyped.assert_not_called() + + @pytest.mark.asyncio + async def test_retries_on_failure(self, processor): + h = AsyncMock() + processor.register_config_handler(h, types=["prompt"]) + + call_count = 0 + + async def fake_fetch_all(client, config_type): nonlocal call_count call_count += 1 if call_count < 3: raise RuntimeError("not ready") - return mock_config, 5 + return {"default": {"k": "v"}}, 5 - with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \ - patch('asyncio.sleep', new_callable=AsyncMock): + mock_client = AsyncMock() + with patch.object( + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_all_workspaces', + new=fake_fetch_all, + ), patch('asyncio.sleep', new_callable=AsyncMock): await processor.fetch_and_apply_config() assert call_count == 3 assert processor.config_version == 5 + h.assert_called_once_with( + "default", {"prompt": {"k": "v"}}, 5 + ) diff --git a/tests/unit/test_base/test_flow_base_modules.py b/tests/unit/test_base/test_flow_base_modules.py index 5bbd7a18..758edcff 100644 --- a/tests/unit/test_base/test_flow_base_modules.py +++ b/tests/unit/test_base/test_flow_base_modules.py @@ -40,10 +40,11 @@ def test_flow_initialization_calls_registered_specs(): spec_two = MagicMock() processor = MagicMock(specifications=[spec_one, spec_two]) - flow = Flow("processor-1", "flow-a", processor, {"answer": 42}) + flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42}) assert flow.id == "processor-1" assert flow.name == "flow-a" + assert flow.workspace == "default" assert flow.producer == {} assert flow.consumer == {} assert flow.parameter == {} @@ -54,7 +55,7 @@ def test_flow_initialization_calls_registered_specs(): def test_flow_start_and_stop_visit_all_consumers(): consumer_one = AsyncMock() consumer_two = AsyncMock() - flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {}) + flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {}) flow.consumer = {"one": consumer_one, "two": consumer_two} asyncio.run(flow.start()) @@ -67,7 +68,7 @@ def test_flow_start_and_stop_visit_all_consumers(): def test_flow_call_returns_values_in_priority_order(): - flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {}) + flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {}) flow.producer["shared"] = "producer-value" flow.consumer["consumer-only"] = "consumer-value" flow.consumer["shared"] = "consumer-value" diff --git a/tests/unit/test_base/test_flow_parameter_specs.py b/tests/unit/test_base/test_flow_parameter_specs.py index c813d66c..da7e9736 100644 --- a/tests/unit/test_base/test_flow_parameter_specs.py +++ b/tests/unit/test_base/test_flow_parameter_specs.py @@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase): flow_defn = {'config': 'test-config'} # Act - await processor.start_flow(flow_name, flow_defn) + await processor.start_flow("default", flow_name, flow_defn) # Assert - Flow should be created with access to processor specifications - mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn) + mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn) # The flow should have access to the processor's specifications # (The exact mechanism depends on Flow implementation) diff --git a/tests/unit/test_base/test_flow_processor.py b/tests/unit/test_base/test_flow_processor.py index 36a05ec2..350a8b43 100644 --- a/tests/unit/test_base/test_flow_processor.py +++ b/tests/unit/test_base/test_flow_processor.py @@ -78,11 +78,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): flow_name = 'test-flow' flow_defn = {'config': 'test-config'} - await processor.start_flow(flow_name, flow_defn) + await processor.start_flow("default", flow_name, flow_defn) - assert flow_name in processor.flows + assert ("default", flow_name) in processor.flows mock_flow_class.assert_called_once_with( - 'test-processor', flow_name, processor, flow_defn + 'test-processor', flow_name, "default", processor, flow_defn ) mock_flow.start.assert_called_once() @@ -103,11 +103,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): mock_flow_class.return_value = mock_flow flow_name = 'test-flow' - await processor.start_flow(flow_name, {'config': 'test-config'}) + await processor.start_flow("default", flow_name, {'config': 'test-config'}) - await processor.stop_flow(flow_name) + await processor.stop_flow("default", flow_name) - assert flow_name not in processor.flows + assert ("default", flow_name) not in processor.flows mock_flow.stop.assert_called_once() @with_async_processor_patches @@ -120,7 +120,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): processor = FlowProcessor(**config) - await processor.stop_flow('non-existent-flow') + await processor.stop_flow("default", 'non-existent-flow') assert processor.flows == {} @@ -146,11 +146,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data, version=1) + await processor.on_configure_flows("default", config_data, version=1) - assert 'test-flow' in processor.flows + assert ("default", 'test-flow') in processor.flows mock_flow_class.assert_called_once_with( - 'test-processor', 'test-flow', processor, + 'test-processor', 'test-flow', "default", processor, {'config': 'test-config'} ) mock_flow.start.assert_called_once() @@ -171,7 +171,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data, version=1) + await processor.on_configure_flows("default", config_data, version=1) assert processor.flows == {} @@ -189,7 +189,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): 'other-data': 'some-value' } - await processor.on_configure_flows(config_data, version=1) + await processor.on_configure_flows("default", config_data, version=1) assert processor.flows == {} @@ -216,7 +216,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data1, version=1) + await processor.on_configure_flows("default", config_data1, version=1) config_data2 = { 'processor:test-processor': { @@ -224,12 +224,12 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data2, version=2) + await processor.on_configure_flows("default", config_data2, version=2) - assert 'flow1' not in processor.flows + assert ("default", 'flow1') not in processor.flows mock_flow1.stop.assert_called_once() - assert 'flow2' in processor.flows + assert ("default", 'flow2') in processor.flows mock_flow2.start.assert_called_once() @with_async_processor_patches diff --git a/tests/unit/test_cli/test_config_commands.py b/tests/unit/test_cli/test_config_commands.py index 68ae1a54..b5b74688 100644 --- a/tests/unit/test_cli/test_config_commands.py +++ b/tests/unit/test_cli/test_config_commands.py @@ -109,7 +109,8 @@ class TestListConfigItems: url='http://custom.com', config_type='prompt', format_type='json', - token=None + token=None, + workspace='default' ) def test_list_main_uses_defaults(self): @@ -128,7 +129,8 @@ class TestListConfigItems: url='http://localhost:8088/', config_type='prompt', format_type='text', - token=None + token=None, + workspace='default' ) @@ -196,7 +198,8 @@ class TestGetConfigItem: config_type='prompt', key='template-1', format_type='json', - token=None + token=None, + workspace='default' ) @@ -253,7 +256,8 @@ class TestPutConfigItem: config_type='prompt', key='new-template', value='Custom prompt: {input}', - token=None + token=None, + workspace='default' ) def test_put_main_with_stdin_arg(self): @@ -278,7 +282,8 @@ class TestPutConfigItem: config_type='prompt', key='stdin-template', value=stdin_content, - token=None + token=None, + workspace='default' ) def test_put_main_mutually_exclusive_args(self): @@ -334,7 +339,8 @@ class TestDeleteConfigItem: url='http://custom.com', config_type='prompt', key='old-template', - token=None + token=None, + workspace='default' ) diff --git a/tests/unit/test_cli/test_load_knowledge.py b/tests/unit/test_cli/test_load_knowledge.py index 63045ef9..853d96cf 100644 --- a/tests/unit/test_cli/test_load_knowledge.py +++ b/tests/unit/test_cli/test_load_knowledge.py @@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob . # Verify Api was created with correct parameters mock_api_class.assert_called_once_with( url="http://test.example.com/", - token="test-token" + token="test-token", + workspace="default" ) # Verify bulk client was obtained diff --git a/tests/unit/test_cli/test_tool_commands.py b/tests/unit/test_cli/test_tool_commands.py index 9c204614..72624d27 100644 --- a/tests/unit/test_cli/test_tool_commands.py +++ b/tests/unit/test_cli/test_tool_commands.py @@ -145,7 +145,8 @@ class TestSetToolStructuredQuery: group=None, state=None, applicable_states=None, - token=None + token=None, + workspace='default' ) def test_set_main_structured_query_no_arguments_needed(self): @@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery: group=None, state=None, applicable_states=None, - token=None + token=None, + workspace='default' ) def test_valid_types_includes_row_embeddings_query(self): @@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery: show_main() - mock_show.assert_called_once_with(url='http://custom.com', token=None) + mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default') class TestShowToolsRowEmbeddingsQuery: diff --git a/tests/unit/test_cores/test_knowledge_manager.py b/tests/unit/test_cores/test_knowledge_manager.py index 80c27fe8..d4b86074 100644 --- a/tests/unit/test_cores/test_knowledge_manager.py +++ b/tests/unit/test_cores/test_knowledge_manager.py @@ -28,10 +28,12 @@ def mock_flow_config(): """Mock flow configuration.""" mock_config = Mock() mock_config.flows = { - "test-flow": { - "interfaces": { - "triples-store": {"flow": "test-triples-queue"}, - "graph-embeddings-store": {"flow": "test-ge-queue"} + "test-user": { + "test-flow": { + "interfaces": { + "triples-store": {"flow": "test-triples-queue"}, + "graph-embeddings-store": {"flow": "test-ge-queue"} + } } } } diff --git a/tests/unit/test_embeddings/test_row_embeddings_processor.py b/tests/unit/test_embeddings/test_row_embeddings_processor.py index 45a22e48..7cf89ff7 100644 --- a/tests/unit/test_embeddings/test_row_embeddings_processor.py +++ b/tests/unit/test_embeddings/test_row_embeddings_processor.py @@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): } } - await processor.on_schema_config(config_data, 1) + await processor.on_schema_config("default", config_data, 1) - assert 'customers' in processor.schemas - assert processor.schemas['customers'].name == 'customers' - assert len(processor.schemas['customers'].fields) == 3 + assert 'customers' in processor.schemas["default"] + assert processor.schemas["default"]['customers'].name == 'customers' + assert len(processor.schemas["default"]['customers'].fields) == 3 async def test_on_schema_config_handles_missing_type(self): """Test that missing schema type is handled gracefully""" @@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): 'other_type': {} } - await processor.on_schema_config(config_data, 1) + await processor.on_schema_config("default", config_data, 1) - assert processor.schemas == {} + assert processor.schemas.get("default", {}) == {} async def test_on_message_drops_unknown_collection(self): """Test that messages for unknown collections are dropped""" @@ -325,14 +325,16 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): processor.known_collections[('test_user', 'test_collection')] = {} # Set up schema - processor.schemas['customers'] = RowSchema( - name='customers', - description='Customer records', - fields=[ - Field(name='id', type='text', primary=True), - Field(name='name', type='text', indexed=True), - ] - ) + processor.schemas["default"] = { + 'customers': RowSchema( + name='customers', + description='Customer records', + fields=[ + Field(name='id', type='text', primary=True), + Field(name='name', type='text', indexed=True), + ] + ) + } metadata = MagicMock() metadata.user = 'test_user' @@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): return MagicMock() mock_flow = MagicMock(side_effect=flow_factory) + mock_flow.workspace = "default" await processor.on_message(mock_msg, MagicMock(), mock_flow) diff --git a/tests/unit/test_gateway/test_config_receiver.py b/tests/unit/test_gateway/test_config_receiver.py index 90ba8d33..56e96178 100644 --- a/tests/unit/test_gateway/test_config_receiver.py +++ b/tests/unit/test_gateway/test_config_receiver.py @@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader ConfigReceiver.config_loader = Mock() +def _notify(version, changes): + msg = Mock() + msg.value.return_value = Mock(version=version, changes=changes) + return msg + + class TestConfigReceiver: """Test cases for ConfigReceiver class""" @@ -47,98 +53,70 @@ class TestConfigReceiver: assert handler2 in config_receiver.flow_handlers @pytest.mark.asyncio - async def test_on_config_notify_new_version(self): - """Test on_config_notify triggers fetch for newer version""" + async def test_on_config_notify_new_version_fetches_per_workspace(self): + """Notify with newer version fetches each affected workspace.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) config_receiver.config_version = 1 - # Mock fetch_and_apply fetch_calls = [] - async def mock_fetch(**kwargs): - fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - # Create notify message with newer version - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=2, types=["flow"]) + async def mock_fetch(workspace, retry=False): + fetch_calls.append(workspace) - await config_receiver.on_config_notify(mock_msg, None, None) + config_receiver.fetch_and_apply_workspace = mock_fetch - assert len(fetch_calls) == 1 + msg = _notify(2, {"flow": ["ws1", "ws2"]}) + await config_receiver.on_config_notify(msg, None, None) + + assert set(fetch_calls) == {"ws1", "ws2"} + assert config_receiver.config_version == 2 @pytest.mark.asyncio async def test_on_config_notify_old_version_ignored(self): - """Test on_config_notify ignores older versions""" + """Older-version notifies are ignored.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) config_receiver.config_version = 5 fetch_calls = [] - async def mock_fetch(**kwargs): - fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - # Create notify message with older version - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=3, types=["flow"]) + async def mock_fetch(workspace, retry=False): + fetch_calls.append(workspace) - await config_receiver.on_config_notify(mock_msg, None, None) + config_receiver.fetch_and_apply_workspace = mock_fetch - assert len(fetch_calls) == 0 + msg = _notify(3, {"flow": ["ws1"]}) + await config_receiver.on_config_notify(msg, None, None) + + assert fetch_calls == [] @pytest.mark.asyncio async def test_on_config_notify_irrelevant_types_ignored(self): - """Test on_config_notify ignores types the gateway doesn't care about""" + """Notifies without flow changes advance version but skip fetch.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) config_receiver.config_version = 1 fetch_calls = [] - async def mock_fetch(**kwargs): - fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - # Create notify message with non-flow type - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=2, types=["prompt"]) + async def mock_fetch(workspace, retry=False): + fetch_calls.append(workspace) - await config_receiver.on_config_notify(mock_msg, None, None) + config_receiver.fetch_and_apply_workspace = mock_fetch - # Version should be updated but no fetch - assert len(fetch_calls) == 0 + msg = _notify(2, {"prompt": ["ws1"]}) + await config_receiver.on_config_notify(msg, None, None) + + assert fetch_calls == [] assert config_receiver.config_version == 2 - @pytest.mark.asyncio - async def test_on_config_notify_flow_type_triggers_fetch(self): - """Test on_config_notify fetches for flow-related types""" - mock_backend = Mock() - config_receiver = ConfigReceiver(mock_backend) - config_receiver.config_version = 1 - - fetch_calls = [] - async def mock_fetch(**kwargs): - fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - - for type_name in ["flow"]: - fetch_calls.clear() - config_receiver.config_version = 1 - - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=2, types=[type_name]) - - await config_receiver.on_config_notify(mock_msg, None, None) - - assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}" - @pytest.mark.asyncio async def test_on_config_notify_exception_handling(self): - """Test on_config_notify handles exceptions gracefully""" + """on_config_notify swallows exceptions from message decode.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Create notify message that causes an exception mock_msg = Mock() mock_msg.value.side_effect = Exception("Test exception") @@ -146,19 +124,18 @@ class TestConfigReceiver: await config_receiver.on_config_notify(mock_msg, None, None) @pytest.mark.asyncio - async def test_fetch_and_apply_with_new_flows(self): - """Test fetch_and_apply starts new flows""" + async def test_fetch_and_apply_workspace_starts_new_flows(self): + """fetch_and_apply_workspace starts newly-configured flows.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Mock _create_config_client to return a mock client mock_resp = Mock() mock_resp.error = None mock_resp.version = 5 mock_resp.config = { "flow": { "flow1": '{"name": "test_flow_1"}', - "flow2": '{"name": "test_flow_2"}' + "flow2": '{"name": "test_flow_2"}', } } @@ -167,36 +144,39 @@ class TestConfigReceiver: config_receiver._create_config_client = Mock(return_value=mock_client) start_flow_calls = [] - async def mock_start_flow(id, flow): - start_flow_calls.append((id, flow)) + + async def mock_start_flow(workspace, id, flow): + start_flow_calls.append((workspace, id, flow)) + config_receiver.start_flow = mock_start_flow - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") assert config_receiver.config_version == 5 - assert "flow1" in config_receiver.flows - assert "flow2" in config_receiver.flows + assert "flow1" in config_receiver.flows["default"] + assert "flow2" in config_receiver.flows["default"] assert len(start_flow_calls) == 2 + assert all(c[0] == "default" for c in start_flow_calls) @pytest.mark.asyncio - async def test_fetch_and_apply_with_removed_flows(self): - """Test fetch_and_apply stops removed flows""" + async def test_fetch_and_apply_workspace_stops_removed_flows(self): + """fetch_and_apply_workspace stops flows no longer configured.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Pre-populate with existing flows config_receiver.flows = { - "flow1": {"name": "test_flow_1"}, - "flow2": {"name": "test_flow_2"} + "default": { + "flow1": {"name": "test_flow_1"}, + "flow2": {"name": "test_flow_2"}, + } } - # Config now only has flow1 mock_resp = Mock() mock_resp.error = None mock_resp.version = 5 mock_resp.config = { "flow": { - "flow1": '{"name": "test_flow_1"}' + "flow1": '{"name": "test_flow_1"}', } } @@ -205,20 +185,22 @@ class TestConfigReceiver: config_receiver._create_config_client = Mock(return_value=mock_client) stop_flow_calls = [] - async def mock_stop_flow(id, flow): - stop_flow_calls.append((id, flow)) + + async def mock_stop_flow(workspace, id, flow): + stop_flow_calls.append((workspace, id, flow)) + config_receiver.stop_flow = mock_stop_flow - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") - assert "flow1" in config_receiver.flows - assert "flow2" not in config_receiver.flows + assert "flow1" in config_receiver.flows["default"] + assert "flow2" not in config_receiver.flows["default"] assert len(stop_flow_calls) == 1 - assert stop_flow_calls[0][0] == "flow2" + assert stop_flow_calls[0][:2] == ("default", "flow2") @pytest.mark.asyncio - async def test_fetch_and_apply_with_no_flows(self): - """Test fetch_and_apply with empty config""" + async def test_fetch_and_apply_workspace_with_no_flows(self): + """Empty workspace config clears any local flow state.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) @@ -231,88 +213,100 @@ class TestConfigReceiver: mock_client.request.return_value = mock_resp config_receiver._create_config_client = Mock(return_value=mock_client) - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") - assert config_receiver.flows == {} + assert config_receiver.flows.get("default", {}) == {} assert config_receiver.config_version == 1 @pytest.mark.asyncio async def test_start_flow_with_handlers(self): - """Test start_flow method with multiple handlers""" + """start_flow fans out to every registered flow handler.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler1 = Mock() - handler1.start_flow = Mock() + handler1.start_flow = AsyncMock() handler2 = Mock() - handler2.start_flow = Mock() + handler2.start_flow = AsyncMock() config_receiver.add_handler(handler1) config_receiver.add_handler(handler2) flow_data = {"name": "test_flow", "steps": []} - await config_receiver.start_flow("flow1", flow_data) + await config_receiver.start_flow("default", "flow1", flow_data) - handler1.start_flow.assert_called_once_with("flow1", flow_data) - handler2.start_flow.assert_called_once_with("flow1", flow_data) + handler1.start_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) + handler2.start_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @pytest.mark.asyncio async def test_start_flow_with_handler_exception(self): - """Test start_flow method handles handler exceptions""" + """Handler exceptions in start_flow do not propagate.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler = Mock() - handler.start_flow = Mock(side_effect=Exception("Handler error")) + handler.start_flow = AsyncMock(side_effect=Exception("Handler error")) config_receiver.add_handler(handler) flow_data = {"name": "test_flow", "steps": []} # Should not raise - await config_receiver.start_flow("flow1", flow_data) + await config_receiver.start_flow("default", "flow1", flow_data) - handler.start_flow.assert_called_once_with("flow1", flow_data) + handler.start_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @pytest.mark.asyncio async def test_stop_flow_with_handlers(self): - """Test stop_flow method with multiple handlers""" + """stop_flow fans out to every registered flow handler.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler1 = Mock() - handler1.stop_flow = Mock() + handler1.stop_flow = AsyncMock() handler2 = Mock() - handler2.stop_flow = Mock() + handler2.stop_flow = AsyncMock() config_receiver.add_handler(handler1) config_receiver.add_handler(handler2) flow_data = {"name": "test_flow", "steps": []} - await config_receiver.stop_flow("flow1", flow_data) + await config_receiver.stop_flow("default", "flow1", flow_data) - handler1.stop_flow.assert_called_once_with("flow1", flow_data) - handler2.stop_flow.assert_called_once_with("flow1", flow_data) + handler1.stop_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) + handler2.stop_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @pytest.mark.asyncio async def test_stop_flow_with_handler_exception(self): - """Test stop_flow method handles handler exceptions""" + """Handler exceptions in stop_flow do not propagate.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler = Mock() - handler.stop_flow = Mock(side_effect=Exception("Handler error")) + handler.stop_flow = AsyncMock(side_effect=Exception("Handler error")) config_receiver.add_handler(handler) flow_data = {"name": "test_flow", "steps": []} # Should not raise - await config_receiver.stop_flow("flow1", flow_data) + await config_receiver.stop_flow("default", "flow1", flow_data) - handler.stop_flow.assert_called_once_with("flow1", flow_data) + handler.stop_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @patch('asyncio.create_task') @pytest.mark.asyncio @@ -329,25 +323,25 @@ class TestConfigReceiver: mock_create_task.assert_called_once() @pytest.mark.asyncio - async def test_fetch_and_apply_mixed_flow_operations(self): - """Test fetch_and_apply with mixed add/remove operations""" + async def test_fetch_and_apply_workspace_mixed_flow_operations(self): + """fetch_and_apply_workspace adds, keeps and removes flows in one pass.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Pre-populate config_receiver.flows = { - "flow1": {"name": "test_flow_1"}, - "flow2": {"name": "test_flow_2"} + "default": { + "flow1": {"name": "test_flow_1"}, + "flow2": {"name": "test_flow_2"}, + } } - # Config removes flow1, keeps flow2, adds flow3 mock_resp = Mock() mock_resp.error = None mock_resp.version = 5 mock_resp.config = { "flow": { "flow2": '{"name": "test_flow_2"}', - "flow3": '{"name": "test_flow_3"}' + "flow3": '{"name": "test_flow_3"}', } } @@ -358,20 +352,22 @@ class TestConfigReceiver: start_calls = [] stop_calls = [] - async def mock_start_flow(id, flow): - start_calls.append((id, flow)) - async def mock_stop_flow(id, flow): - stop_calls.append((id, flow)) + async def mock_start_flow(workspace, id, flow): + start_calls.append((workspace, id, flow)) + + async def mock_stop_flow(workspace, id, flow): + stop_calls.append((workspace, id, flow)) config_receiver.start_flow = mock_start_flow config_receiver.stop_flow = mock_stop_flow - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") - assert "flow1" not in config_receiver.flows - assert "flow2" in config_receiver.flows - assert "flow3" in config_receiver.flows + ws_flows = config_receiver.flows["default"] + assert "flow1" not in ws_flows + assert "flow2" in ws_flows + assert "flow3" in ws_flows assert len(start_calls) == 1 - assert start_calls[0][0] == "flow3" + assert start_calls[0][:2] == ("default", "flow3") assert len(stop_calls) == 1 - assert stop_calls[0][0] == "flow1" + assert stop_calls[0][:2] == ("default", "flow1") diff --git a/tests/unit/test_gateway/test_dispatch_manager.py b/tests/unit/test_gateway/test_dispatch_manager.py index 4ebcb5b9..f091a46d 100644 --- a/tests/unit/test_gateway/test_dispatch_manager.py +++ b/tests/unit/test_gateway/test_dispatch_manager.py @@ -72,10 +72,10 @@ class TestDispatcherManager: flow_data = {"name": "test_flow", "steps": []} - await manager.start_flow("flow1", flow_data) - - assert "flow1" in manager.flows - assert manager.flows["flow1"] == flow_data + await manager.start_flow("default", "flow1", flow_data) + + assert ("default", "flow1") in manager.flows + assert manager.flows[("default", "flow1")] == flow_data @pytest.mark.asyncio async def test_stop_flow(self): @@ -86,11 +86,11 @@ class TestDispatcherManager: # Pre-populate with a flow flow_data = {"name": "test_flow", "steps": []} - manager.flows["flow1"] = flow_data - - await manager.stop_flow("flow1", flow_data) - - assert "flow1" not in manager.flows + manager.flows[("default", "flow1")] = flow_data + + await manager.stop_flow("default", "flow1", flow_data) + + assert ("default", "flow1") not in manager.flows def test_dispatch_global_service_returns_wrapper(self): """Test dispatch_global_service returns DispatcherWrapper""" @@ -275,12 +275,12 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "triples-store": {"flow": "test_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \ patch('uuid.uuid4') as mock_uuid: mock_uuid.return_value = "test-uuid" @@ -290,7 +290,7 @@ class TestDispatcherManager: mock_dispatcher_class.return_value = mock_dispatcher mock_dispatchers.__getitem__.return_value = mock_dispatcher_class mock_dispatchers.__contains__.return_value = True - + params = {"flow": "test_flow", "kind": "triples"} result = await manager.process_flow_import("ws", "running", params) @@ -326,12 +326,12 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "triples-store": {"flow": "test_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers: mock_dispatchers.__contains__.return_value = False @@ -348,12 +348,12 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "triples-store": {"flow": "test_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \ patch('uuid.uuid4') as mock_uuid: mock_uuid.return_value = "test-uuid" @@ -404,7 +404,7 @@ class TestDispatcherManager: params = {"flow": "test_flow", "kind": "agent"} result = await manager.process_flow_service("data", "responder", params) - manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent") + manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent") assert result == "flow_result" @pytest.mark.asyncio @@ -415,14 +415,14 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Add flow to the flows dictionary - manager.flows["test_flow"] = {"services": {"agent": {}}} - + manager.flows[("default", "test_flow")] = {"services": {"agent": {}}} + # Pre-populate with existing dispatcher mock_dispatcher = Mock() mock_dispatcher.process = AsyncMock(return_value="cached_result") - manager.dispatchers[("test_flow", "agent")] = mock_dispatcher - - result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent") + manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher + + result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") mock_dispatcher.process.assert_called_once_with("data", "responder") assert result == "cached_result" @@ -435,7 +435,7 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "agent": { "request": "agent_request_queue", @@ -443,7 +443,7 @@ class TestDispatcherManager: } } } - + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers: mock_dispatcher_class = Mock() mock_dispatcher = Mock() @@ -452,23 +452,23 @@ class TestDispatcherManager: mock_dispatcher_class.return_value = mock_dispatcher mock_dispatchers.__getitem__.return_value = mock_dispatcher_class mock_dispatchers.__contains__.return_value = True - - result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent") - + + result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") + # Verify dispatcher was created with correct parameters mock_dispatcher_class.assert_called_once_with( backend=mock_backend, request_queue="agent_request_queue", response_queue="agent_response_queue", timeout=120, - consumer="api-gateway-test_flow-agent-request", - subscriber="api-gateway-test_flow-agent-request" + consumer="api-gateway-default-test_flow-agent-request", + subscriber="api-gateway-default-test_flow-agent-request" ) mock_dispatcher.start.assert_called_once() mock_dispatcher.process.assert_called_once_with("data", "responder") - + # Verify dispatcher was cached - assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher + assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher assert result == "new_result" @pytest.mark.asyncio @@ -479,26 +479,26 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "text-load": {"flow": "text_load_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \ patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers: mock_rr_dispatchers.__contains__.return_value = False mock_sender_dispatchers.__contains__.return_value = True - + mock_dispatcher_class = Mock() mock_dispatcher = Mock() mock_dispatcher.start = AsyncMock() mock_dispatcher.process = AsyncMock(return_value="sender_result") mock_dispatcher_class.return_value = mock_dispatcher mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class - - result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load") - + + result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load") + # Verify dispatcher was created with correct parameters mock_dispatcher_class.assert_called_once_with( backend=mock_backend, @@ -506,9 +506,9 @@ class TestDispatcherManager: ) mock_dispatcher.start.assert_called_once() mock_dispatcher.process.assert_called_once_with("data", "responder") - + # Verify dispatcher was cached - assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher + assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher assert result == "sender_result" @pytest.mark.asyncio @@ -519,7 +519,7 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) with pytest.raises(RuntimeError, match="Invalid flow"): - await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent") + await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent") @pytest.mark.asyncio async def test_invoke_flow_service_unsupported_kind_by_flow(self): @@ -529,14 +529,14 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow without agent interface - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "text-completion": {"request": "req", "response": "resp"} } } - + with pytest.raises(RuntimeError, match="This kind not supported by flow"): - await manager.invoke_flow_service("data", "responder", "test_flow", "agent") + await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") @pytest.mark.asyncio async def test_invoke_flow_service_invalid_kind(self): @@ -546,7 +546,7 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow with interface but unsupported kind - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "invalid-kind": {"request": "req", "response": "resp"} } @@ -558,7 +558,7 @@ class TestDispatcherManager: mock_sender_dispatchers.__contains__.return_value = False with pytest.raises(RuntimeError, match="Invalid kind"): - await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind") + await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind") @pytest.mark.asyncio async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self): @@ -608,7 +608,7 @@ class TestDispatcherManager: mock_config_receiver = Mock() manager = DispatcherManager(mock_backend, mock_config_receiver) - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "agent": { "request": "agent_request_queue", @@ -630,7 +630,7 @@ class TestDispatcherManager: mock_rr_dispatchers.__contains__.return_value = True results = await asyncio.gather(*[ - manager.invoke_flow_service("data", "responder", "test_flow", "agent") + manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") for _ in range(5) ]) @@ -638,5 +638,5 @@ class TestDispatcherManager: "Dispatcher class instantiated more than once — duplicate consumer bug" ) assert mock_dispatcher.start.call_count == 1 - assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher + assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher assert all(r == "result" for r in results) \ No newline at end of file diff --git a/tests/unit/test_provenance/test_dag_structure.py b/tests/unit/test_provenance/test_dag_structure.py index 184560f0..08256ec0 100644 --- a/tests/unit/test_provenance/test_dag_structure.py +++ b/tests/unit/test_provenance/test_dag_structure.py @@ -239,7 +239,7 @@ def _make_processor(tools=None): agent = MagicMock() agent.tools = tools or {} agent.additional_context = "" - processor.agent = agent + processor.agents = {"default": agent} processor.aggregator = MagicMock() return processor @@ -254,6 +254,7 @@ def _make_flow(): return producers[name] flow = MagicMock(side_effect=factory) + flow.workspace = "default" return flow @@ -299,7 +300,7 @@ class TestAgentReactDagStructure: service.max_iterations = 10 service.save_answer_content = AsyncMock() service.provenance_session_uri = processor.provenance_session_uri - service.agent = processor.agent + service.agents = processor.agents service.aggregator = processor.aggregator service.react_pattern = ReactPattern(service) @@ -433,7 +434,7 @@ class TestAgentPlanDagStructure: service.max_iterations = 10 service.save_answer_content = AsyncMock() service.provenance_session_uri = processor.provenance_session_uri - service.agent = processor.agent + service.agents = processor.agents service.aggregator = processor.aggregator service.react_pattern = ReactPattern(service) @@ -537,7 +538,7 @@ class TestAgentSupervisorDagStructure: service.max_iterations = 10 service.save_answer_content = AsyncMock() service.provenance_session_uri = processor.provenance_session_uri - service.agent = processor.agent + service.agents = processor.agents service.aggregator = processor.aggregator service.react_pattern = ReactPattern(service) diff --git a/tests/unit/test_query/test_rows_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py index c0d399c3..fca25242 100644 --- a/tests/unit/test_query/test_rows_cassandra_query.py +++ b/tests/unit/test_query/test_rows_cassandra_query.py @@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic: """Test parsing of schema configuration""" processor = MagicMock() processor.schemas = {} + processor.schema_builders = {} + processor.graphql_schemas = {} processor.config_key = "schema" - processor.schema_builder = MagicMock() - processor.schema_builder.clear = MagicMock() - processor.schema_builder.add_schema = MagicMock() - processor.schema_builder.build = MagicMock(return_value=MagicMock()) + processor.query_cassandra = MagicMock() processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Create test config @@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic: } # Process config - await processor.on_schema_config(schema_config, version=1) + await processor.on_schema_config("default", schema_config, version=1) # Verify schema was loaded - assert "customer" in processor.schemas - schema = processor.schemas["customer"] + assert "customer" in processor.schemas["default"] + schema = processor.schemas["default"]["customer"] assert schema.name == "customer" assert len(schema.fields) == 3 @@ -147,24 +146,26 @@ class TestRowsGraphQLQueryLogic: status_field = next(f for f in schema.fields if f.name == "status") assert status_field.enum_values == ["active", "inactive"] - # Verify schema builder was called - processor.schema_builder.add_schema.assert_called_once() - processor.schema_builder.build.assert_called_once() + # Verify per-workspace schema builder was created and graphql schema built + assert "default" in processor.schema_builders + assert "default" in processor.graphql_schemas @pytest.mark.asyncio async def test_graphql_context_handling(self): """Test GraphQL execution context setup""" processor = MagicMock() - processor.graphql_schema = AsyncMock() + graphql_schema = AsyncMock() + processor.graphql_schemas = {"default": graphql_schema} processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor) # Mock schema execution mock_result = MagicMock() mock_result.data = {"customers": [{"id": "1", "name": "Test"}]} mock_result.errors = None - processor.graphql_schema.execute.return_value = mock_result + graphql_schema.execute.return_value = mock_result result = await processor.execute_graphql_query( + workspace="default", query='{ customers { id name } }', variables={}, operation_name=None, @@ -173,8 +174,8 @@ class TestRowsGraphQLQueryLogic: ) # Verify schema.execute was called with correct context - processor.graphql_schema.execute.assert_called_once() - call_args = processor.graphql_schema.execute.call_args + graphql_schema.execute.assert_called_once() + call_args = graphql_schema.execute.call_args # Verify context was passed context = call_args[1]['context_value'] @@ -190,7 +191,8 @@ class TestRowsGraphQLQueryLogic: async def test_error_handling_graphql_errors(self): """Test GraphQL error handling and conversion""" processor = MagicMock() - processor.graphql_schema = AsyncMock() + graphql_schema = AsyncMock() + processor.graphql_schemas = {"default": graphql_schema} processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor) # Create a simple object to simulate GraphQL error @@ -212,9 +214,10 @@ class TestRowsGraphQLQueryLogic: mock_result = MagicMock() mock_result.data = None mock_result.errors = [mock_error] - processor.graphql_schema.execute.return_value = mock_result + graphql_schema.execute.return_value = mock_result result = await processor.execute_graphql_query( + workspace="default", query='{ customers { invalid_field } }', variables={}, operation_name=None, @@ -259,6 +262,7 @@ class TestRowsGraphQLQueryLogic: # Mock flow mock_flow = MagicMock() + mock_flow.workspace = "default" mock_response_flow = AsyncMock() mock_flow.return_value = mock_response_flow @@ -267,6 +271,7 @@ class TestRowsGraphQLQueryLogic: # Verify query was executed processor.execute_graphql_query.assert_called_once_with( + workspace="default", query='{ customers { id name } }', variables={}, operation_name=None, diff --git a/tests/unit/test_retrieval/test_nlp_query.py b/tests/unit/test_retrieval/test_nlp_query.py index 1fd35c2e..cc285aea 100644 --- a/tests/unit/test_retrieval/test_nlp_query.py +++ b/tests/unit/test_retrieval/test_nlp_query.py @@ -286,11 +286,11 @@ class TestNLPQueryProcessor: } # Act - await processor.on_schema_config(config, "v1") + await processor.on_schema_config("default", config, "v1") # Assert - assert "test_schema" in processor.schemas - schema = processor.schemas["test_schema"] + assert "test_schema" in processor.schemas["default"] + schema = processor.schemas["default"]["test_schema"] assert schema.name == "test_schema" assert schema.description == "Test schema" assert len(schema.fields) == 2 @@ -308,10 +308,10 @@ class TestNLPQueryProcessor: } # Act - await processor.on_schema_config(config, "v1") + await processor.on_schema_config("default", config, "v1") # Assert - bad schema should be ignored - assert "bad_schema" not in processor.schemas + assert "bad_schema" not in processor.schemas.get("default", {}) def test_processor_initialization(self, mock_pulsar_client): """Test processor initialization with correct specifications""" diff --git a/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py b/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py index 8ce1b97e..45ba9fda 100644 --- a/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py +++ b/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py @@ -101,7 +101,7 @@ def service(mock_schemas): taskgroup=MagicMock(), id="test-processor" ) - service.schemas = mock_schemas + service.schemas = {"default": dict(mock_schemas)} return service @@ -109,6 +109,7 @@ def service(mock_schemas): def mock_flow(): """Create mock flow with prompt service""" flow = MagicMock() + flow.workspace = "default" prompt_request_flow = AsyncMock() flow.return_value.request = prompt_request_flow return flow, prompt_request_flow diff --git a/tests/unit/test_storage/test_rows_cassandra_storage.py b/tests/unit/test_storage/test_rows_cassandra_storage.py index ccf193aa..b252c958 100644 --- a/tests/unit/test_storage/test_rows_cassandra_storage.py +++ b/tests/unit/test_storage/test_rows_cassandra_storage.py @@ -17,6 +17,17 @@ from trustgraph.storage.rows.cassandra.write import Processor from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field + + +class _MockFlowDefault: + """Mock Flow with default workspace for testing.""" + workspace = "default" + name = "default" + id = "test-processor" + + +mock_flow_default = _MockFlowDefault() + class TestRowsCassandraStorageLogic: """Test business logic for unified table implementation""" @@ -145,11 +156,11 @@ class TestRowsCassandraStorageLogic: } # Process configuration - await processor.on_schema_config(config, version=1) + await processor.on_schema_config("default", config, version=1) # Verify schema was loaded - assert "customer_records" in processor.schemas - schema = processor.schemas["customer_records"] + assert "customer_records" in processor.schemas["default"] + schema = processor.schemas["default"]["customer_records"] assert schema.name == "customer_records" assert len(schema.fields) == 3 @@ -165,14 +176,16 @@ class TestRowsCassandraStorageLogic: """Test that row processing stores data as map""" processor = MagicMock() processor.schemas = { - "test_schema": RowSchema( - name="test_schema", - description="Test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="value", type="string", size=100) - ] - ) + "default": { + "test_schema": RowSchema( + name="test_schema", + description="Test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="value", type="string", size=100) + ] + ) + } } processor.tables_initialized = {"test_user"} processor.registered_partitions = set() @@ -205,7 +218,7 @@ class TestRowsCassandraStorageLogic: msg.value.return_value = test_obj # Process object - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify insert was executed mock_async_execute.assert_called() @@ -230,14 +243,16 @@ class TestRowsCassandraStorageLogic: """Test that row is written once per indexed field""" processor = MagicMock() processor.schemas = { - "multi_index_schema": RowSchema( - name="multi_index_schema", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="category", type="string", indexed=True), - Field(name="status", type="string", indexed=True) - ] - ) + "default": { + "multi_index_schema": RowSchema( + name="multi_index_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="status", type="string", indexed=True) + ] + ) + } } processor.tables_initialized = {"test_user"} processor.registered_partitions = set() @@ -267,7 +282,7 @@ class TestRowsCassandraStorageLogic: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should have 3 inserts (one per indexed field: id, category, status) assert mock_async_execute.call_count == 3 @@ -290,13 +305,15 @@ class TestRowsCassandraStorageBatchLogic: """Test processing of batch ExtractedObjects""" processor = MagicMock() processor.schemas = { - "batch_schema": RowSchema( - name="batch_schema", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="name", type="string") - ] - ) + "default": { + "batch_schema": RowSchema( + name="batch_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="name", type="string") + ] + ) + } } processor.tables_initialized = {"test_user"} processor.registered_partitions = set() @@ -331,7 +348,7 @@ class TestRowsCassandraStorageBatchLogic: msg = MagicMock() msg.value.return_value = batch_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should have 3 inserts (one per row, one index per row since only primary key) assert mock_async_execute.call_count == 3 @@ -349,10 +366,12 @@ class TestRowsCassandraStorageBatchLogic: """Test processing of empty batch ExtractedObjects""" processor = MagicMock() processor.schemas = { - "empty_schema": RowSchema( - name="empty_schema", - fields=[Field(name="id", type="string", primary=True)] - ) + "default": { + "empty_schema": RowSchema( + name="empty_schema", + fields=[Field(name="id", type="string", primary=True)] + ) + } } processor.tables_initialized = {"test_user"} processor.registered_partitions = set() @@ -381,7 +400,7 @@ class TestRowsCassandraStorageBatchLogic: msg = MagicMock() msg.value.return_value = empty_batch_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify no insert calls for empty batch processor.session.execute.assert_not_called() @@ -446,19 +465,21 @@ class TestPartitionRegistration: processor.registered_partitions = set() processor.session = MagicMock() processor.schemas = { - "test_schema": RowSchema( - name="test_schema", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="category", type="string", indexed=True) - ] - ) + "default": { + "test_schema": RowSchema( + name="test_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True) + ] + ) + } } processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) - processor.register_partitions("test_user", "test_collection", "test_schema") + processor.register_partitions("test_user", "test_collection", "test_schema", "default") # Should have 2 inserts (one per index: id, category) assert processor.session.execute.call_count == 2 @@ -473,7 +494,7 @@ class TestPartitionRegistration: processor.session = MagicMock() processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) - processor.register_partitions("test_user", "test_collection", "test_schema") + processor.register_partitions("test_user", "test_collection", "test_schema", "default") # Should not execute any CQL since already registered processor.session.execute.assert_not_called() diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index dbdce0a8..d5f1bac5 100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -50,7 +50,7 @@ class Api: token: Optional bearer token for authentication """ - def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None): + def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None, workspace: str = "default"): """ Initialize the TrustGraph API client. @@ -82,6 +82,7 @@ class Api: self.timeout = timeout self.token = token + self.workspace = workspace # Lazy initialization for new clients self._socket_client = None @@ -137,7 +138,7 @@ class Api: config.put([ConfigValue(type="llm", key="model", value="gpt-4")]) ``` """ - return Config(api=self) + return Config(api=self, workspace=self.workspace) def knowledge(self): """ @@ -191,6 +192,12 @@ class Api: if self.token: headers["Authorization"] = f"Bearer {self.token}" + # Ensure every REST request carries the workspace so services can + # scope their behaviour. Callers that already set workspace in the + # payload (e.g. Library client) take precedence. + if isinstance(request, dict) and "workspace" not in request: + request = {**request, "workspace": self.workspace} + # Invoke the API, input is passed as JSON resp = requests.post(url, json=request, timeout=self.timeout, headers=headers) @@ -297,7 +304,10 @@ class Api: from . socket_client import SocketClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._socket_client = SocketClient(base_url, self.timeout, self.token) + self._socket_client = SocketClient( + base_url, self.timeout, self.token, + workspace=self.workspace, + ) return self._socket_client def bulk(self): @@ -417,7 +427,10 @@ class Api: from . async_socket_client import AsyncSocketClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._async_socket_client = AsyncSocketClient(base_url, self.timeout, self.token) + self._async_socket_client = AsyncSocketClient( + base_url, self.timeout, self.token, + workspace=self.workspace, + ) return self._async_socket_client def async_bulk(self): diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 6e5064ab..da581f73 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -22,10 +22,14 @@ class AsyncSocketClient: Or call connect()/aclose() manually. """ - def __init__(self, url: str, timeout: int, token: Optional[str]): + def __init__( + self, url: str, timeout: int, token: Optional[str], + workspace: str = "default", + ): self.url = self._convert_to_ws_url(url) self.timeout = timeout self.token = token + self.workspace = workspace self._request_counter = 0 self._socket = None self._connect_cm = None @@ -117,6 +121,7 @@ class AsyncSocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -149,6 +154,7 @@ class AsyncSocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } diff --git a/trustgraph-base/trustgraph/api/collection.py b/trustgraph-base/trustgraph/api/collection.py index 414d07db..11cd2843 100644 --- a/trustgraph-base/trustgraph/api/collection.py +++ b/trustgraph-base/trustgraph/api/collection.py @@ -2,11 +2,9 @@ TrustGraph Collection Management This module provides interfaces for managing data collections in TrustGraph. -Collections provide logical grouping and isolation for documents and knowledge -graph data. +Collections provide logical grouping within a workspace. """ -import datetime import logging from . types import CollectionMetadata @@ -18,10 +16,9 @@ class Collection: """ Collection management client. - Provides methods for managing data collections, including listing, - updating metadata, and deleting collections. Collections organize - documents and knowledge graph data into logical groupings for - isolation and access control. + Provides methods for managing data collections within the configured + workspace, including listing, updating metadata, and deleting + collections. """ def __init__(self, api): @@ -45,45 +42,20 @@ class Collection: """ return self.api.request(f"collection-management", request) - def list_collections(self, user, tag_filter=None): + def list_collections(self, tag_filter=None): """ - List all collections for a user. - - Retrieves metadata for all collections owned by the specified user, - with optional filtering by tags. + List all collections in this workspace. Args: - user: User identifier - tag_filter: Optional list of tags to filter collections (default: None) + tag_filter: Optional list of tags to filter collections Returns: list[CollectionMetadata]: List of collection metadata objects - - Raises: - ProtocolException: If response format is invalid - - Example: - ```python - collection = api.collection() - - # List all collections - all_colls = collection.list_collections(user="trustgraph") - for coll in all_colls: - print(f"{coll.collection}: {coll.name}") - print(f" Description: {coll.description}") - print(f" Tags: {', '.join(coll.tags)}") - - # List collections with specific tags - research_colls = collection.list_collections( - user="trustgraph", - tag_filter=["research", "published"] - ) - ``` """ input = { "operation": "list-collections", - "user": user, + "workspace": self.api.workspace, } if tag_filter: @@ -92,7 +64,6 @@ class Collection: object = self.request(input) try: - # Handle case where collections might be None or missing if object is None or "collections" not in object: return [] @@ -102,7 +73,6 @@ class Collection: return [ CollectionMetadata( - user = v["user"], collection = v["collection"], name = v["name"], description = v["description"], @@ -114,15 +84,11 @@ class Collection: logger.error("Failed to parse collection list response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def update_collection(self, user, collection, name=None, description=None, tags=None): + def update_collection(self, collection, name=None, description=None, tags=None): """ Update collection metadata. - Updates the name, description, and/or tags for an existing collection. - Only provided fields are updated; others remain unchanged. - Args: - user: User identifier collection: Collection identifier name: New collection name (optional) description: New collection description (optional) @@ -130,35 +96,11 @@ class Collection: Returns: CollectionMetadata: Updated collection metadata, or None if not found - - Raises: - ProtocolException: If response format is invalid - - Example: - ```python - collection_api = api.collection() - - # Update collection metadata - updated = collection_api.update_collection( - user="trustgraph", - collection="default", - name="Default Collection", - description="Main data collection for general use", - tags=["default", "production"] - ) - - # Update only specific fields - updated = collection_api.update_collection( - user="trustgraph", - collection="research", - description="Updated description" - ) - ``` """ input = { "operation": "update-collection", - "user": user, + "workspace": self.api.workspace, "collection": collection, } @@ -175,7 +117,6 @@ class Collection: if "collections" in object and object["collections"]: v = object["collections"][0] return CollectionMetadata( - user = v["user"], collection = v["collection"], name = v["name"], description = v["description"], @@ -186,37 +127,23 @@ class Collection: logger.error("Failed to parse collection update response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def delete_collection(self, user, collection): + def delete_collection(self, collection): """ Delete a collection. - Removes a collection and all its associated data from the system. - Args: - user: User identifier collection: Collection identifier to delete Returns: dict: Empty response object - - Example: - ```python - collection_api = api.collection() - - # Delete a collection - collection_api.delete_collection( - user="trustgraph", - collection="old-collection" - ) - ``` """ input = { "operation": "delete-collection", - "user": user, + "workspace": self.api.workspace, "collection": collection, } - object = self.request(input) + self.request(input) - return {} \ No newline at end of file + return {} diff --git a/trustgraph-base/trustgraph/api/config.py b/trustgraph-base/trustgraph/api/config.py index c8c8d5bb..5f17672f 100644 --- a/trustgraph-base/trustgraph/api/config.py +++ b/trustgraph-base/trustgraph/api/config.py @@ -21,14 +21,16 @@ class Config: and list operations. """ - def __init__(self, api): + def __init__(self, api, workspace="default"): """ Initialize Config client. Args: api: Parent Api instance for making requests + workspace: Workspace to scope all config operations to """ self.api = api + self.workspace = workspace def request(self, request): """ @@ -75,9 +77,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "get", + "workspace": self.workspace, "keys": [ { "type": k.type, "key": k.key } for k in keys @@ -123,9 +125,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "put", + "workspace": self.workspace, "values": [ { "type": v.type, "key": v.key, "value": v.value } for v in values @@ -157,9 +159,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "delete", + "workspace": self.workspace, "keys": [ { "type": v.type, "key": v.key } for v in keys @@ -195,9 +197,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "list", + "workspace": self.workspace, "type": type, } @@ -235,9 +237,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "getvalues", + "workspace": self.workspace, "type": type, } @@ -255,6 +257,46 @@ class Config: except: raise ProtocolException(f"Response not formatted correctly") + def get_values_all_workspaces(self, type): + """ + Get all configuration values of a given type across all workspaces. + + Unlike get_values(), this is not scoped to a single workspace — + it returns every entry of the given type in the system. Each + returned ConfigValue includes its workspace field. Used by + shared processors to load type-scoped config at startup. + + Args: + type: Configuration type (e.g. "prompt", "schema") + + Returns: + list[ConfigValue]: Values across all workspaces; each has + its workspace field populated. + + Raises: + ProtocolException: If response format is invalid + """ + + input = { + "operation": "getvalues-all-ws", + "type": type, + } + + object = self.request(input) + + try: + return [ + ConfigValue( + type = v["type"], + key = v["key"], + value = v["value"], + workspace = v.get("workspace", ""), + ) + for v in object["values"] + ] + except Exception: + raise ProtocolException("Response not formatted correctly") + def all(self): """ Get complete configuration and version. @@ -279,9 +321,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { - "operation": "config" + "operation": "config", + "workspace": self.workspace, } object = self.request(input) diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 7ee32dad..947b7a56 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -115,72 +115,32 @@ class Flow: return FlowInstance(api=self, id=id) def list_blueprints(self): - """ - List all available flow blueprints. + """List blueprints in the current workspace.""" - Returns: - list[str]: List of blueprint names - - Example: - ```python - blueprints = api.flow().list_blueprints() - print(blueprints) # ['default', 'custom-flow', ...] - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "list-blueprints", + "workspace": self.api.workspace, } return self.request(request = input)["blueprint-names"] def get_blueprint(self, blueprint_name): - """ - Get a flow blueprint definition by name. + """Get a flow blueprint definition by name.""" - Args: - blueprint_name: Name of the blueprint to retrieve - - Returns: - dict: Blueprint definition as a dictionary - - Example: - ```python - blueprint = api.flow().get_blueprint("default") - print(blueprint) # Blueprint configuration - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "get-blueprint", + "workspace": self.api.workspace, "blueprint-name": blueprint_name, } return json.loads(self.request(request = input)["blueprint-definition"]) def put_blueprint(self, blueprint_name, definition): - """ - Create or update a flow blueprint. + """Create or update a flow blueprint.""" - Args: - blueprint_name: Name for the blueprint - definition: Blueprint definition dictionary - - Example: - ```python - definition = { - "services": ["text-completion", "graph-rag"], - "parameters": {"model": "gpt-4"} - } - api.flow().put_blueprint("my-blueprint", definition) - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "put-blueprint", + "workspace": self.api.workspace, "blueprint-name": blueprint_name, "blueprint-definition": json.dumps(definition), } @@ -188,96 +148,43 @@ class Flow: self.request(request = input) def delete_blueprint(self, blueprint_name): - """ - Delete a flow blueprint. + """Delete a flow blueprint.""" - Args: - blueprint_name: Name of the blueprint to delete - - Example: - ```python - api.flow().delete_blueprint("old-blueprint") - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "delete-blueprint", + "workspace": self.api.workspace, "blueprint-name": blueprint_name, } self.request(request = input) def list(self): - """ - List all active flow instances. + """List flow instances in the current workspace.""" - Returns: - list[str]: List of flow instance IDs - - Example: - ```python - flows = api.flow().list() - print(flows) # ['default', 'flow-1', 'flow-2', ...] - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "list-flows", + "workspace": self.api.workspace, } return self.request(request = input)["flow-ids"] def get(self, id): - """ - Get the definition of a running flow instance. + """Get the definition of a flow instance.""" - Args: - id: Flow instance ID - - Returns: - dict: Flow instance definition - - Example: - ```python - flow_def = api.flow().get("default") - print(flow_def) - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "get-flow", + "workspace": self.api.workspace, "flow-id": id, } return json.loads(self.request(request = input)["flow"]) def start(self, blueprint_name, id, description, parameters=None): - """ - Start a new flow instance from a blueprint. + """Start a new flow instance from a blueprint.""" - Args: - blueprint_name: Name of the blueprint to instantiate - id: Unique identifier for the flow instance - description: Human-readable description - parameters: Optional parameters dictionary - - Example: - ```python - api.flow().start( - blueprint_name="default", - id="my-flow", - description="My custom flow", - parameters={"model": "gpt-4"} - ) - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "start-flow", + "workspace": self.api.workspace, "flow-id": id, "blueprint-name": blueprint_name, "description": description, @@ -289,21 +196,11 @@ class Flow: self.request(request = input) def stop(self, id): - """ - Stop a running flow instance. + """Stop a running flow instance.""" - Args: - id: Flow instance ID to stop - - Example: - ```python - api.flow().stop("my-flow") - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "stop-flow", + "workspace": self.api.workspace, "flow-id": id, } @@ -349,6 +246,13 @@ class FlowInstance: Returns: dict: Service response """ + # Inject workspace so the gateway can route to the right + # workspace's flow. If already present, keep the caller's value. + if isinstance(request, dict) and "workspace" not in request: + request = { + "workspace": self.api.api.workspace, + **request, + } return self.api.request(path = f"{self.id}/{path}", request = request) def text_completion(self, system, prompt): diff --git a/trustgraph-base/trustgraph/api/knowledge.py b/trustgraph-base/trustgraph/api/knowledge.py index 84f98918..c3ec2308 100644 --- a/trustgraph-base/trustgraph/api/knowledge.py +++ b/trustgraph-base/trustgraph/api/knowledge.py @@ -63,105 +63,50 @@ class Knowledge: """ return self.api.request(f"knowledge", request) - def list_kg_cores(self, user="trustgraph"): + def list_kg_cores(self): """ - List all available knowledge graph cores. - - Retrieves the IDs of all KG cores available for the specified user. - - Args: - user: User identifier (default: "trustgraph") + List all available knowledge graph cores in this workspace. Returns: list[str]: List of KG core identifiers - - Example: - ```python - knowledge = api.knowledge() - - # List available KG cores - cores = knowledge.list_kg_cores(user="trustgraph") - print(f"Available KG cores: {cores}") - ``` """ - # The input consists of system and prompt strings input = { "operation": "list-kg-cores", - "user": user, + "workspace": self.api.workspace, } return self.request(request = input)["ids"] - def delete_kg_core(self, id, user="trustgraph"): + def delete_kg_core(self, id): """ - Delete a knowledge graph core. - - Removes a KG core from storage. This does not affect currently loaded - cores in flows. + Delete a knowledge graph core in this workspace. Args: id: KG core identifier to delete - user: User identifier (default: "trustgraph") - - Example: - ```python - knowledge = api.knowledge() - - # Delete a KG core - knowledge.delete_kg_core(id="medical-kb-v1", user="trustgraph") - ``` """ - # The input consists of system and prompt strings input = { "operation": "delete-kg-core", - "user": user, + "workspace": self.api.workspace, "id": id, } self.request(request = input) - def load_kg_core(self, id, user="trustgraph", flow="default", - collection="default"): + def load_kg_core(self, id, flow="default", collection="default"): """ Load a knowledge graph core into a flow. - Makes a KG core available for use in queries and RAG operations within - the specified flow and collection. - Args: id: KG core identifier to load - user: User identifier (default: "trustgraph") flow: Flow instance to load into (default: "default") collection: Collection to associate with (default: "default") - - Example: - ```python - knowledge = api.knowledge() - - # Load a medical knowledge base into the default flow - knowledge.load_kg_core( - id="medical-kb-v1", - user="trustgraph", - flow="default", - collection="medical" - ) - - # Now the flow can use this KG core for RAG queries - flow = api.flow().id("default") - response = flow.graph_rag( - query="What are the symptoms of diabetes?", - user="trustgraph", - collection="medical" - ) - ``` """ - # The input consists of system and prompt strings input = { "operation": "load-kg-core", - "user": user, + "workspace": self.api.workspace, "id": id, "flow": flow, "collection": collection, @@ -169,35 +114,18 @@ class Knowledge: self.request(request = input) - def unload_kg_core(self, id, user="trustgraph", flow="default"): + def unload_kg_core(self, id, flow="default"): """ Unload a knowledge graph core from a flow. - Removes a KG core from active use in the specified flow, freeing - resources while keeping the core available in storage. - Args: id: KG core identifier to unload - user: User identifier (default: "trustgraph") flow: Flow instance to unload from (default: "default") - - Example: - ```python - knowledge = api.knowledge() - - # Unload a KG core when no longer needed - knowledge.unload_kg_core( - id="medical-kb-v1", - user="trustgraph", - flow="default" - ) - ``` """ - # The input consists of system and prompt strings input = { "operation": "unload-kg-core", - "user": user, + "workspace": self.api.workspace, "id": id, "flow": flow, } diff --git a/trustgraph-base/trustgraph/api/library.py b/trustgraph-base/trustgraph/api/library.py index c66598aa..8865d553 100644 --- a/trustgraph-base/trustgraph/api/library.py +++ b/trustgraph-base/trustgraph/api/library.py @@ -94,7 +94,7 @@ class Library: return self.api.request(f"librarian", request) def add_document( - self, document, id, metadata, user, title, comments, + self, document, id, metadata, title, comments, kind="text/plain", tags=[], on_progress=None, ): """ @@ -176,7 +176,6 @@ class Library: document=document, id=id, metadata=metadata, - user=user, title=title, comments=comments, kind=kind, @@ -213,6 +212,7 @@ class Library: input = { "operation": "add-document", + "workspace": self.api.workspace, "document-metadata": { "id": id, "time": int(time.time()), @@ -220,7 +220,7 @@ class Library: "title": title, "comments": comments, "metadata": triples, - "user": user, + "workspace": self.api.workspace, "tags": tags }, "content": base64.b64encode(document).decode("utf-8"), @@ -229,7 +229,7 @@ class Library: return self.request(input) def _add_document_chunked( - self, document, id, metadata, user, title, comments, + self, document, id, metadata, title, comments, kind, tags, on_progress=None, ): """ @@ -245,13 +245,14 @@ class Library: # Begin upload session begin_request = { "operation": "begin-upload", + "workspace": self.api.workspace, "document-metadata": { "id": id, "time": int(time.time()), "kind": kind, "title": title, "comments": comments, - "user": user, + "workspace": self.api.workspace, "tags": tags, }, "total-size": total_size, @@ -279,10 +280,10 @@ class Library: chunk_request = { "operation": "upload-chunk", + "workspace": self.api.workspace, "upload-id": upload_id, "chunk-index": chunk_index, "content": base64.b64encode(chunk_data).decode("utf-8"), - "user": user, } chunk_response = self.request(chunk_request) @@ -298,8 +299,8 @@ class Library: # Complete upload complete_request = { "operation": "complete-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } complete_response = self.request(complete_request) @@ -314,8 +315,8 @@ class Library: try: abort_request = { "operation": "abort-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } self.request(abort_request) logger.info(f"Aborted failed upload {upload_id}") @@ -323,7 +324,7 @@ class Library: logger.warning(f"Failed to abort upload: {abort_error}") raise - def get_documents(self, user, include_children=False): + def get_documents(self, include_children=False): """ List all documents for a user. @@ -359,7 +360,7 @@ class Library: input = { "operation": "list-documents", - "user": user, + "workspace": self.api.workspace, "include-children": include_children, } @@ -381,7 +382,7 @@ class Library: ) for w in v["metadata"] ], - user = v["user"], + workspace = v.get("workspace", ""), tags = v["tags"], parent_id = v.get("parent-id", ""), document_type = v.get("document-type", "source"), @@ -392,7 +393,7 @@ class Library: logger.error("Failed to parse document list response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def get_document(self, user, id): + def get_document(self, id): """ Get metadata for a specific document. @@ -419,7 +420,7 @@ class Library: input = { "operation": "get-document", - "user": user, + "workspace": self.api.workspace, "document-id": id, } @@ -441,7 +442,7 @@ class Library: ) for w in doc["metadata"] ], - user = doc["user"], + workspace = doc.get("workspace", ""), tags = doc["tags"], parent_id = doc.get("parent-id", ""), document_type = doc.get("document-type", "source"), @@ -450,7 +451,7 @@ class Library: logger.error("Failed to parse document response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def update_document(self, user, id, metadata): + def update_document(self, id, metadata): """ Update document metadata. @@ -490,8 +491,9 @@ class Library: input = { "operation": "update-document", + "workspace": self.api.workspace, "document-metadata": { - "user": user, + "workspace": self.api.workspace, "document-id": id, "time": metadata.time, "title": metadata.title, @@ -526,14 +528,14 @@ class Library: ) for w in doc["metadata"] ], - user = doc["user"], + workspace = doc.get("workspace", ""), tags = doc["tags"] ) except Exception as e: logger.error("Failed to parse document update response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def remove_document(self, user, id): + def remove_document(self, id): """ Remove a document from the library. @@ -555,7 +557,7 @@ class Library: input = { "operation": "remove-document", - "user": user, + "workspace": self.api.workspace, "document-id": id, } @@ -565,7 +567,7 @@ class Library: def start_processing( self, id, document_id, flow="default", - user="trustgraph", collection="default", tags=[], + collection="default", tags=[], ): """ Start a document processing workflow. @@ -602,12 +604,13 @@ class Library: input = { "operation": "add-processing", + "workspace": self.api.workspace, "processing-metadata": { "id": id, "document-id": document_id, "time": int(time.time()), "flow": flow, - "user": user, + "workspace": self.api.workspace, "collection": collection, "tags": tags, } @@ -618,7 +621,7 @@ class Library: return {} def stop_processing( - self, id, user="trustgraph", + self, id, ): """ Stop a running document processing job. @@ -641,15 +644,15 @@ class Library: input = { "operation": "remove-processing", + "workspace": self.api.workspace, "processing-id": id, - "user": user, } object = self.request(input) return {} - def get_processings(self, user="trustgraph"): + def get_processings(self): """ List all active document processing jobs. @@ -681,7 +684,7 @@ class Library: input = { "operation": "list-processing", - "user": user, + "workspace": self.api.workspace, } object = self.request(input) @@ -693,7 +696,7 @@ class Library: document_id = v["document-id"], time = datetime.datetime.fromtimestamp(v["time"]), flow = v["flow"], - user = v["user"], + workspace = v.get("workspace", ""), collection = v["collection"], tags = v["tags"], ) @@ -705,7 +708,7 @@ class Library: # Chunked upload management methods - def get_pending_uploads(self, user): + def get_pending_uploads(self): """ List all pending (in-progress) uploads for a user. @@ -731,14 +734,14 @@ class Library: """ input = { "operation": "list-uploads", - "user": user, + "workspace": self.api.workspace, } response = self.request(input) return response.get("upload-sessions", []) - def get_upload_status(self, upload_id, user): + def get_upload_status(self, upload_id): """ Get the status of a specific upload. @@ -774,13 +777,13 @@ class Library: """ input = { "operation": "get-upload-status", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } return self.request(input) - def abort_upload(self, upload_id, user): + def abort_upload(self, upload_id): """ Abort an in-progress upload. @@ -801,13 +804,13 @@ class Library: """ input = { "operation": "abort-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } return self.request(input) - def resume_upload(self, upload_id, document, user, on_progress=None): + def resume_upload(self, upload_id, document, on_progress=None): """ Resume an interrupted upload. @@ -844,7 +847,7 @@ class Library: ``` """ # Get current status - status = self.get_upload_status(upload_id, user) + status = self.get_upload_status(upload_id) if status.get("upload-state") == "expired": raise RuntimeError("Upload session has expired, please start a new upload") @@ -867,10 +870,10 @@ class Library: chunk_request = { "operation": "upload-chunk", + "workspace": self.api.workspace, "upload-id": upload_id, "chunk-index": chunk_index, "content": base64.b64encode(chunk_data).decode("utf-8"), - "user": user, } self.request(chunk_request) @@ -886,8 +889,8 @@ class Library: # Complete upload complete_request = { "operation": "complete-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } return self.request(complete_request) @@ -895,7 +898,7 @@ class Library: # Child document methods def add_child_document( - self, document, id, parent_id, user, title, comments, + self, document, id, parent_id, title, comments, kind="text/plain", tags=[], metadata=None, ): """ @@ -964,6 +967,7 @@ class Library: input = { "operation": "add-child-document", + "workspace": self.api.workspace, "document-metadata": { "id": id, "time": int(time.time()), @@ -971,7 +975,7 @@ class Library: "title": title, "comments": comments, "metadata": triples, - "user": user, + "workspace": self.api.workspace, "tags": tags, "parent-id": parent_id, "document-type": "extracted", @@ -981,7 +985,7 @@ class Library: return self.request(input) - def list_children(self, document_id, user): + def list_children(self, document_id): """ List all child documents for a given parent document. @@ -1006,8 +1010,8 @@ class Library: """ input = { "operation": "list-children", + "workspace": self.api.workspace, "document-id": document_id, - "user": user, } response = self.request(input) @@ -1028,7 +1032,7 @@ class Library: ) for w in v.get("metadata", []) ], - user=v["user"], + workspace=v.get("workspace", ""), tags=v.get("tags", []), parent_id=v.get("parent-id", ""), document_type=v.get("document-type", "source"), @@ -1039,7 +1043,7 @@ class Library: logger.error("Failed to parse children response", exc_info=True) raise ProtocolException("Response not formatted correctly") - def get_document_content(self, user, id): + def get_document_content(self, id): """ Get the content of a document. @@ -1067,7 +1071,7 @@ class Library: """ input = { "operation": "get-document-content", - "user": user, + "workspace": self.api.workspace, "document-id": id, } @@ -1076,7 +1080,7 @@ class Library: return base64.b64decode(content_b64) - def stream_document_to_file(self, user, id, file_path, chunk_size=1024*1024, on_progress=None): + def stream_document_to_file(self, id, file_path, chunk_size=1024*1024, on_progress=None): """ Stream document content to a file. @@ -1116,7 +1120,7 @@ class Library: while True: input = { "operation": "stream-document", - "user": user, + "workspace": self.api.workspace, "document-id": id, "chunk-index": chunk_index, "chunk-size": chunk_size, diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index c590c9b4..aee9d450 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -84,10 +84,14 @@ class SocketClient: for streaming responses. """ - def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + def __init__( + self, url: str, timeout: int, token: Optional[str], + workspace: str = "default", + ) -> None: self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token + self.workspace: str = workspace self._request_counter: int = 0 self._lock: Lock = Lock() self._loop: Optional[asyncio.AbstractEventLoop] = None @@ -251,6 +255,7 @@ class SocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -290,6 +295,7 @@ class SocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -328,6 +334,7 @@ class SocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index f5987b0e..129f807a 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -45,10 +45,13 @@ class ConfigValue: type: Configuration type/category key: Specific configuration key value: Configuration value as string + workspace: Workspace the value belongs to. Only populated for + responses to getvalues-all-ws; empty otherwise. """ type : str key : str value : str + workspace : str = "" @dataclasses.dataclass class DocumentMetadata: @@ -62,7 +65,7 @@ class DocumentMetadata: title: Document title comments: Additional comments or description metadata: List of RDF triples providing structured metadata - user: User/owner identifier + workspace: Workspace the document belongs to tags: List of tags for categorization parent_id: Parent document ID for child documents (empty for top-level docs) document_type: "source" for uploaded documents, "extracted" for derived content @@ -73,7 +76,7 @@ class DocumentMetadata: title : str comments : str metadata : List[Triple] - user : str + workspace : str tags : List[str] parent_id : str = "" document_type : str = "source" @@ -88,7 +91,7 @@ class ProcessingMetadata: document_id: ID of the document being processed time: Processing start timestamp flow: Flow instance handling the processing - user: User identifier + workspace: Workspace the processing job belongs to collection: Target collection for processed data tags: List of tags for categorization """ @@ -96,7 +99,7 @@ class ProcessingMetadata: document_id : str time : datetime.datetime flow : str - user : str + workspace : str collection : str tags : List[str] @@ -105,17 +108,15 @@ class CollectionMetadata: """ Metadata for a data collection. - Collections provide logical grouping and isolation for documents and - knowledge graph data. + Collections provide logical grouping within a workspace for documents + and knowledge graph data. Attributes: - user: User/owner identifier collection: Collection identifier name: Human-readable collection name description: Collection description tags: List of tags for categorization """ - user : str collection : str name : str description : str diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index 9b9328cb..a7ce4961 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -125,21 +125,39 @@ class AsyncProcessor: response_metrics = config_resp_metrics, ) - async def fetch_config(self): - """Fetch full config from config service using a short-lived - request/response client. Returns (config, version) or raises.""" - client = self._create_config_client() - try: - await client.start() - resp = await client.request( - ConfigRequest(operation="config"), - timeout=10, - ) - if resp.error: - raise RuntimeError(f"Config error: {resp.error.message}") - return resp.config, resp.version - finally: - await client.stop() + async def _fetch_type_workspace(self, client, workspace, config_type): + """Fetch config values of a single type within one workspace. + Returns dict of {key: value}.""" + resp = await client.request( + ConfigRequest( + operation="getvalues", + workspace=workspace, + type=config_type, + ), + timeout=10, + ) + if resp.error: + raise RuntimeError(f"Config error: {resp.error.message}") + return {v.key: v.value for v in resp.values} + + async def _fetch_type_all_workspaces(self, client, config_type): + """Fetch config values of a single type across all workspaces. + Returns dict of {workspace: {key: value}}.""" + resp = await client.request( + ConfigRequest( + operation="getvalues-all-ws", + type=config_type, + ), + timeout=10, + ) + if resp.error: + raise RuntimeError(f"Config error: {resp.error.message}") + + grouped = {} + for v in resp.values: + ws = grouped.setdefault(v.workspace, {}) + ws[v.key] = v.value + return grouped, resp.version # This is called to start dynamic behaviour. # Implements the subscribe-then-fetch pattern to avoid race conditions. @@ -155,21 +173,51 @@ class AsyncProcessor: # processed by on_config_notify, which does the version check async def fetch_and_apply_config(self): - """Fetch full config from config service and apply to all handlers. - Retries until successful — config service may not be ready yet.""" + """Startup: for each registered handler, fetch config for all its + types across all workspaces and invoke the handler once per + workspace. Retries until successful — config service may not be + ready yet.""" while self.running: try: - config, version = await self.fetch_config() + client = self._create_config_client() + try: + await client.start() - logger.info(f"Fetched config version {version}") + version = 0 - self.config_version = version + for entry in self.config_handlers: + handler_types = entry["types"] - # Apply to all handlers (startup = invoke all) - for entry in self.config_handlers: - await entry["handler"](config, version) + # Handlers registered without types get nothing + # at startup (there is no "all types" fetch). + if not handler_types: + continue + + # Group all registered types by workspace: + # {workspace: {type: {key: value}}} + per_ws = {} + for t in handler_types: + type_data, v = \ + await self._fetch_type_all_workspaces( + client, t, + ) + version = max(version, v) + for ws, kv in type_data.items(): + per_ws.setdefault(ws, {})[t] = kv + + # Call the handler once per workspace + for ws, config in per_ws.items(): + await entry["handler"](ws, config, version) + + logger.info( + f"Applied startup config version {version}" + ) + self.config_version = version + + finally: + await client.stop() return @@ -204,8 +252,9 @@ class AsyncProcessor: # Called when a config notify message arrives async def on_config_notify(self, message, consumer, flow): - notify_version = message.value().version - notify_types = set(message.value().types) + v = message.value() + notify_version = v.version + changes = v.changes # dict of type -> [workspaces] # Skip if we already have this version or newer if notify_version <= self.config_version: @@ -215,41 +264,60 @@ class AsyncProcessor: ) return - # Check if any handler cares about the affected types - if notify_types: - any_interested = False - for entry in self.config_handlers: - handler_types = entry["types"] - if handler_types is None or notify_types & handler_types: - any_interested = True - break + notify_types = set(changes.keys()) - if not any_interested: - logger.debug( - f"Ignoring config notify v{notify_version}, " - f"no handlers for types {notify_types}" - ) - self.config_version = notify_version - return + # Filter out handlers that don't care about any of the changed + # types. A handler registered without types never fires on + # notifications (nothing to scope to). + interested = [] + for entry in self.config_handlers: + handler_types = entry["types"] + if handler_types and notify_types & handler_types: + interested.append(entry) + + if not interested: + logger.debug( + f"Ignoring config notify v{notify_version}, " + f"no handlers for types {notify_types}" + ) + self.config_version = notify_version + return logger.info( - f"Config notify v{notify_version} types={list(notify_types)}, " - f"fetching config..." + f"Config notify v{notify_version} " + f"types={list(notify_types)}, fetching config..." ) - # Fetch full config using short-lived client try: - config, version = await self.fetch_config() + client = self._create_config_client() + try: + await client.start() - self.config_version = version + for entry in interested: + handler_types = entry["types"] - # Invoke handlers that care about the affected types - for entry in self.config_handlers: - handler_types = entry["types"] - if handler_types is None: - await entry["handler"](config, version) - elif not notify_types or notify_types & handler_types: - await entry["handler"](config, version) + # Build {workspace: {type: {key: value}}} for types + # this handler cares about, where the workspace was + # affected for that type. + per_ws = {} + for t in handler_types: + if t not in changes: + continue + for ws in changes[t]: + kv = await self._fetch_type_workspace( + client, ws, t, + ) + per_ws.setdefault(ws, {})[t] = kv + + for ws, config in per_ws.items(): + await entry["handler"]( + ws, config, notify_version, + ) + + finally: + await client.stop() + + self.config_version = notify_version except Exception as e: logger.error( diff --git a/trustgraph-base/trustgraph/base/chunking_service.py b/trustgraph-base/trustgraph/base/chunking_service.py index 4bd78428..3771d78e 100644 --- a/trustgraph-base/trustgraph/base/chunking_service.py +++ b/trustgraph-base/trustgraph/base/chunking_service.py @@ -48,12 +48,13 @@ class ChunkingService(FlowProcessor): await super(ChunkingService, self).start() await self.librarian.start() - async def get_document_text(self, doc): + async def get_document_text(self, doc, workspace): """ Get text content from a TextDocument, fetching from librarian if needed. Args: doc: TextDocument with either inline text or document_id + workspace: Workspace for librarian lookup (from flow.workspace) Returns: str: The document text content @@ -62,7 +63,7 @@ class ChunkingService(FlowProcessor): logger.info(f"Fetching document {doc.document_id} from librarian...") text = await self.librarian.fetch_document_text( document_id=doc.document_id, - user=doc.metadata.user, + workspace=workspace, ) logger.info(f"Fetched {len(text)} characters from librarian") return text diff --git a/trustgraph-base/trustgraph/base/collection_config_handler.py b/trustgraph-base/trustgraph/base/collection_config_handler.py index 8c1af822..4cb91c53 100644 --- a/trustgraph-base/trustgraph/base/collection_config_handler.py +++ b/trustgraph-base/trustgraph/base/collection_config_handler.py @@ -15,114 +15,139 @@ class CollectionConfigHandler: Storage services should: 1. Inherit from this class along with their service base class 2. Call register_config_handler(self.on_collection_config) in __init__ - 3. Implement create_collection(user, collection, metadata) method - 4. Implement delete_collection(user, collection) method + 3. Implement create_collection(workspace, collection, metadata) method + 4. Implement delete_collection(workspace, collection) method """ def __init__(self, **kwargs): - # Track known collections: {(user, collection): metadata_dict} + # Track known collections: {(workspace, collection): metadata_dict} self.known_collections: Dict[tuple, dict] = {} # Pass remaining kwargs up the inheritance chain super().__init__(**kwargs) - async def on_collection_config(self, config: dict, version: int): + async def on_collection_config( + self, workspace: str, config: dict, version: int + ): """ Handle config push messages and extract collection information + for a single workspace. Args: + workspace: Workspace the config applies to config: Configuration dictionary from ConfigPush message version: Configuration version number """ - logger.info(f"Processing collection configuration (version {version})") + logger.info( + f"Processing collection configuration " + f"(version {version}, workspace {workspace})" + ) - # Extract collections from config (treat missing key as empty) + # Extract collections from config (treat missing key as empty). + # Each config key IS the collection name — config is already + # partitioned by workspace, so no workspace prefix is needed + # on the key. collection_config = config.get("collection", {}) # Track which collections we've seen in this config current_collections: Set[tuple] = set() - # Process each collection in the config - for key, value_json in collection_config.items(): + for collection, value_json in collection_config.items(): try: - # Parse user:collection key - if ":" not in key: - logger.warning(f"Invalid collection key format (expected user:collection): {key}") - continue + current_collections.add((workspace, collection)) - user, collection = key.split(":", 1) - current_collections.add((user, collection)) - - # Parse metadata metadata = json.loads(value_json) - # Check if this is a new collection or updated - collection_key = (user, collection) - if collection_key not in self.known_collections: - logger.info(f"New collection detected: {user}/{collection}") - await self.create_collection(user, collection, metadata) - self.known_collections[collection_key] = metadata + key = (workspace, collection) + if key not in self.known_collections: + logger.info( + f"New collection detected: {workspace}/{collection}" + ) + await self.create_collection( + workspace, collection, metadata + ) + self.known_collections[key] = metadata else: - # Collection already exists, update metadata if changed - if self.known_collections[collection_key] != metadata: - logger.info(f"Collection metadata updated: {user}/{collection}") - # Most storage services don't need to do anything for metadata updates - # They just need to know the collection exists - self.known_collections[collection_key] = metadata + if self.known_collections[key] != metadata: + logger.info( + f"Collection metadata updated: " + f"{workspace}/{collection}" + ) + self.known_collections[key] = metadata except Exception as e: - logger.error(f"Error processing collection config for key {key}: {e}", exc_info=True) + logger.error( + f"Error processing collection config for " + f"{workspace}/{collection}: {e}", + exc_info=True, + ) - # Find collections that were deleted (in known but not in current) - deleted_collections = set(self.known_collections.keys()) - current_collections - for user, collection in deleted_collections: - logger.info(f"Collection deleted: {user}/{collection}") + # Find collections for THIS workspace that were deleted (in + # known but not in current). Only compare collections owned by + # this workspace — other workspaces' collections are not + # affected by this config update. + known_for_ws = { + (w, c) for (w, c) in self.known_collections.keys() + if w == workspace + } + deleted_collections = known_for_ws - current_collections + for ws, collection in deleted_collections: + logger.info(f"Collection deleted: {ws}/{collection}") try: - # Remove from known_collections FIRST to immediately reject new writes - # This eliminates race condition with worker threads - del self.known_collections[(user, collection)] - # Physical deletion happens after - worker threads already rejecting writes - await self.delete_collection(user, collection) + # Remove from known_collections FIRST to immediately + # reject new writes + del self.known_collections[(ws, collection)] + await self.delete_collection(ws, collection) except Exception as e: - logger.error(f"Error deleting collection {user}/{collection}: {e}", exc_info=True) - # If physical deletion failed, should we re-add to known_collections? - # For now, keep it removed - collection is logically deleted per config + logger.error( + f"Error deleting collection {ws}/{collection}: {e}", + exc_info=True, + ) - logger.debug(f"Collection config processing complete. Known collections: {len(self.known_collections)}") + logger.debug( + f"Collection config processing complete. " + f"Known collections: {len(self.known_collections)}" + ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection( + self, workspace: str, collection: str, metadata: dict, + ): """ Create a collection in the storage backend. Subclasses must implement this method. Args: - user: User ID + workspace: Workspace ID collection: Collection ID metadata: Collection metadata dictionary """ - raise NotImplementedError("Storage service must implement create_collection method") + raise NotImplementedError( + "Storage service must implement create_collection method" + ) - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """ Delete a collection from the storage backend. Subclasses must implement this method. Args: - user: User ID + workspace: Workspace ID collection: Collection ID """ - raise NotImplementedError("Storage service must implement delete_collection method") + raise NotImplementedError( + "Storage service must implement delete_collection method" + ) - def collection_exists(self, user: str, collection: str) -> bool: + def collection_exists(self, workspace: str, collection: str) -> bool: """ - Check if a collection is known to exist + Check if a collection is known to exist. Args: - user: User ID + workspace: Workspace ID collection: Collection ID Returns: True if collection exists, False otherwise """ - return (user, collection) in self.known_collections + return (workspace, collection) in self.known_collections diff --git a/trustgraph-base/trustgraph/base/config_client.py b/trustgraph-base/trustgraph/base/config_client.py index c9ec3f9b..504a6d58 100644 --- a/trustgraph-base/trustgraph/base/config_client.py +++ b/trustgraph-base/trustgraph/base/config_client.py @@ -18,10 +18,11 @@ class ConfigClient(RequestResponse): ) return resp - async def get(self, type, key, timeout=CONFIG_TIMEOUT): + async def get(self, workspace, type, key, timeout=CONFIG_TIMEOUT): """Get a single config value. Returns the value string or None.""" resp = await self._request( operation="get", + workspace=workspace, keys=[ConfigKey(type=type, key=key)], timeout=timeout, ) @@ -29,19 +30,21 @@ class ConfigClient(RequestResponse): return resp.values[0].value return None - async def put(self, type, key, value, timeout=CONFIG_TIMEOUT): + async def put(self, workspace, type, key, value, timeout=CONFIG_TIMEOUT): """Put a single config value.""" await self._request( operation="put", + workspace=workspace, values=[ConfigValue(type=type, key=key, value=value)], timeout=timeout, ) - async def put_many(self, values, timeout=CONFIG_TIMEOUT): - """Put multiple config values in a single request. - values is a list of (type, key, value) tuples.""" + async def put_many(self, workspace, values, timeout=CONFIG_TIMEOUT): + """Put multiple config values in a single request within a + single workspace. values is a list of (type, key, value) tuples.""" await self._request( operation="put", + workspace=workspace, values=[ ConfigValue(type=t, key=k, value=v) for t, k, v in values @@ -49,19 +52,21 @@ class ConfigClient(RequestResponse): timeout=timeout, ) - async def delete(self, type, key, timeout=CONFIG_TIMEOUT): + async def delete(self, workspace, type, key, timeout=CONFIG_TIMEOUT): """Delete a single config key.""" await self._request( operation="delete", + workspace=workspace, keys=[ConfigKey(type=type, key=key)], timeout=timeout, ) - async def delete_many(self, keys, timeout=CONFIG_TIMEOUT): - """Delete multiple config keys in a single request. - keys is a list of (type, key) tuples.""" + async def delete_many(self, workspace, keys, timeout=CONFIG_TIMEOUT): + """Delete multiple config keys in a single request within a + single workspace. keys is a list of (type, key) tuples.""" await self._request( operation="delete", + workspace=workspace, keys=[ ConfigKey(type=t, key=k) for t, k in keys @@ -69,15 +74,26 @@ class ConfigClient(RequestResponse): timeout=timeout, ) - async def keys(self, type, timeout=CONFIG_TIMEOUT): - """List all keys for a config type.""" + async def keys(self, workspace, type, timeout=CONFIG_TIMEOUT): + """List all keys for a config type within a workspace.""" resp = await self._request( operation="list", + workspace=workspace, type=type, timeout=timeout, ) return resp.directory + async def workspaces_for_type(self, type, timeout=CONFIG_TIMEOUT): + """Return the set of distinct workspaces with any config of + the given type.""" + resp = await self._request( + operation="getvalues-all-ws", + type=type, + timeout=timeout, + ) + return {v.workspace for v in resp.values if v.workspace} + class ConfigClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/consumer_spec.py b/trustgraph-base/trustgraph/base/consumer_spec.py index 023537df..af072cca 100644 --- a/trustgraph-base/trustgraph/base/consumer_spec.py +++ b/trustgraph-base/trustgraph/base/consumer_spec.py @@ -24,7 +24,10 @@ class ConsumerSpec(Spec): flow = flow, backend = processor.pubsub, topic = definition["topics"][self.name], - subscriber = processor.id + "--" + flow.name + "--" + self.name, + subscriber = ( + processor.id + "--" + flow.workspace + "--" + + flow.name + "--" + self.name + ), schema = self.schema, handler = self.handler, metrics = consumer_metrics, diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py index d5bf8421..cd9e91b1 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py @@ -60,7 +60,9 @@ class DocumentEmbeddingsQueryService(FlowProcessor): logger.debug(f"Handling document embeddings query request {id}...") - docs = await self.query_document_embeddings(request) + docs = await self.query_document_embeddings( + flow.workspace, request, + ) logger.debug("Sending document embeddings query response...") r = DocumentEmbeddingsResponse(chunks=docs, error=None) diff --git a/trustgraph-base/trustgraph/base/document_embeddings_store_service.py b/trustgraph-base/trustgraph/base/document_embeddings_store_service.py index 0c7921db..96b7781f 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_store_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_store_service.py @@ -41,7 +41,8 @@ class DocumentEmbeddingsStoreService(FlowProcessor): request = msg.value() - await self.store_document_embeddings(request) + # Workspace comes from the flow the message arrived on. + await self.store_document_embeddings(flow.workspace, request) except TooManyRequests as e: raise e diff --git a/trustgraph-base/trustgraph/base/flow.py b/trustgraph-base/trustgraph/base/flow.py index 9a515bf8..2caad938 100644 --- a/trustgraph-base/trustgraph/base/flow.py +++ b/trustgraph-base/trustgraph/base/flow.py @@ -4,15 +4,16 @@ import asyncio class Flow: """ Runtime representation of a deployed flow process. - + This class maintains internal processor states and orchestrates - lifecycles (start, stop) for inputs (consumers) and parameters + lifecycles (start, stop) for inputs (consumers) and parameters that drive data flowing across linked nodes. """ - def __init__(self, id, flow, processor, defn): + def __init__(self, id, flow, workspace, processor, defn): self.id = id self.name = flow + self.workspace = workspace self.producer = {} diff --git a/trustgraph-base/trustgraph/base/flow_processor.py b/trustgraph-base/trustgraph/base/flow_processor.py index 99cb0f53..aa7bf921 100644 --- a/trustgraph-base/trustgraph/base/flow_processor.py +++ b/trustgraph-base/trustgraph/base/flow_processor.py @@ -35,6 +35,8 @@ class FlowProcessor(AsyncProcessor): ) # Initialise flow information state + # Keyed by (workspace, flow) tuples; each workspace has its own + # set of flow variants for this processor. self.flows = {} # These can be overriden by a derived class: @@ -48,23 +50,28 @@ class FlowProcessor(AsyncProcessor): def register_specification(self, spec: Any) -> None: self.specifications.append(spec) - # Start processing for a new flow - async def start_flow(self, flow, defn): - self.flows[flow] = Flow(self.id, flow, self, defn) - await self.flows[flow].start() - logger.info(f"Started flow: {flow}") - - # Stop processing for a new flow - async def stop_flow(self, flow): - if flow in self.flows: - await self.flows[flow].stop() - del self.flows[flow] - logger.info(f"Stopped flow: {flow}") + # Start processing for a new flow within a workspace + async def start_flow(self, workspace, flow, defn): + key = (workspace, flow) + self.flows[key] = Flow(self.id, flow, workspace, self, defn) + await self.flows[key].start() + logger.info(f"Started flow: {workspace}/{flow}") - # Event handler - called for a configuration change - async def on_configure_flows(self, config, version): + # Stop processing for a flow within a workspace + async def stop_flow(self, workspace, flow): + key = (workspace, flow) + if key in self.flows: + await self.flows[key].stop() + del self.flows[key] + logger.info(f"Stopped flow: {workspace}/{flow}") - logger.info(f"Got config version {version}") + # Event handler - called for a configuration change for a single + # workspace + async def on_configure_flows(self, workspace, config, version): + + logger.info( + f"Got config version {version} for workspace {workspace}" + ) config_type = f"processor:{self.id}" @@ -76,26 +83,28 @@ class FlowProcessor(AsyncProcessor): for k, v in config[config_type].items() } else: - logger.debug("No configuration settings for me.") + logger.debug( + f"No configuration settings for me in {workspace}." + ) flow_config = {} - # Get list of flows which should be running and are currently - # running - wanted_flows = flow_config.keys() - # This takes a copy, needed because dict gets modified by stop_flow - current_flows = list(self.flows.keys()) + # Get list of flows which should be running in this workspace, + # and the list currently running in this workspace + wanted_flows = set(flow_config.keys()) + current_flows = { + f for (ws, f) in self.flows.keys() if ws == workspace + } - # Start all the flows which arent currently running - for flow in wanted_flows: - if flow not in current_flows: - await self.start_flow(flow, flow_config[flow]) + # Start all the flows which aren't currently running in this + # workspace + for flow in wanted_flows - current_flows: + await self.start_flow(workspace, flow, flow_config[flow]) - # Stop all the unwanted flows which are due to be stopped - for flow in current_flows: - if flow not in wanted_flows: - await self.stop_flow(flow) + # Stop all the unwanted flows in this workspace + for flow in current_flows - wanted_flows: + await self.stop_flow(workspace, flow) - logger.info("Handled config update") + logger.info(f"Handled config update for workspace {workspace}") # Start threads, just call parent async def start(self): diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py index 55c8efa9..cbce810c 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py @@ -60,7 +60,9 @@ class GraphEmbeddingsQueryService(FlowProcessor): logger.debug(f"Handling graph embeddings query request {id}...") - entities = await self.query_graph_embeddings(request) + entities = await self.query_graph_embeddings( + flow.workspace, request, + ) logger.debug("Sending graph embeddings query response...") r = GraphEmbeddingsResponse(entities=entities, error=None) diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py index 09bbbe6a..10cfe93c 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py @@ -41,7 +41,8 @@ class GraphEmbeddingsStoreService(FlowProcessor): request = msg.value() - await self.store_graph_embeddings(request) + # Workspace comes from the flow the message arrived on. + await self.store_graph_embeddings(flow.workspace, request) except TooManyRequests as e: raise e diff --git a/trustgraph-base/trustgraph/base/librarian_client.py b/trustgraph-base/trustgraph/base/librarian_client.py index 5ad97f47..3df244d2 100644 --- a/trustgraph-base/trustgraph/base/librarian_client.py +++ b/trustgraph-base/trustgraph/base/librarian_client.py @@ -150,7 +150,7 @@ class LibrarianClient: finally: self._streams.pop(request_id, None) - async def fetch_document_content(self, document_id, user, timeout=120): + async def fetch_document_content(self, document_id, workspace, timeout=120): """Fetch document content using streaming. Returns base64-encoded content. Caller is responsible for decoding. @@ -158,7 +158,7 @@ class LibrarianClient: req = LibrarianRequest( operation="stream-document", document_id=document_id, - user=user, + workspace=workspace, ) chunks = await self.stream(req, timeout=timeout) @@ -176,24 +176,24 @@ class LibrarianClient: return base64.b64encode(raw) - async def fetch_document_text(self, document_id, user, timeout=120): + async def fetch_document_text(self, document_id, workspace, timeout=120): """Fetch document content and decode as UTF-8 text.""" content = await self.fetch_document_content( - document_id, user, timeout=timeout, + document_id, workspace, timeout=timeout, ) return base64.b64decode(content).decode("utf-8") - async def fetch_document_metadata(self, document_id, user, timeout=120): + async def fetch_document_metadata(self, document_id, workspace, timeout=120): """Fetch document metadata from the librarian.""" req = LibrarianRequest( operation="get-document-metadata", document_id=document_id, - user=user, + workspace=workspace, ) response = await self.request(req, timeout=timeout) return response.document_metadata - async def save_child_document(self, doc_id, parent_id, user, content, + async def save_child_document(self, doc_id, parent_id, workspace, content, document_type="chunk", title=None, kind="text/plain", timeout=120): """Save a child document to the librarian.""" @@ -202,7 +202,7 @@ class LibrarianClient: doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind=kind, title=title or doc_id, parent_id=parent_id, @@ -218,7 +218,7 @@ class LibrarianClient: await self.request(req, timeout=timeout) return doc_id - async def save_document(self, doc_id, user, content, title=None, + async def save_document(self, doc_id, workspace, content, title=None, document_type="answer", kind="text/plain", timeout=120): """Save a document to the librarian.""" @@ -227,7 +227,7 @@ class LibrarianClient: doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind=kind, title=title or doc_id, document_type=document_type, @@ -238,7 +238,7 @@ class LibrarianClient: document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content).decode("utf-8"), - user=user, + workspace=workspace, ) await self.request(req, timeout=timeout) diff --git a/trustgraph-base/trustgraph/base/request_response_spec.py b/trustgraph-base/trustgraph/base/request_response_spec.py index b91c655c..aa934a7f 100644 --- a/trustgraph-base/trustgraph/base/request_response_spec.py +++ b/trustgraph-base/trustgraph/base/request_response_spec.py @@ -133,8 +133,9 @@ class RequestResponseSpec(Spec): # Make subscription names unique, so that all subscribers get # to see all response messages subscription = ( - processor.id + "--" + flow.name + "--" + self.request_name + - "--" + str(uuid.uuid4()) + processor.id + "--" + flow.workspace + "--" + + flow.name + "--" + self.request_name + "--" + + str(uuid.uuid4()) ), consumer_name = flow.id, request_topic = definition["topics"][self.request_name], diff --git a/trustgraph-base/trustgraph/base/subscriber_spec.py b/trustgraph-base/trustgraph/base/subscriber_spec.py index bf35f869..80f9b0d5 100644 --- a/trustgraph-base/trustgraph/base/subscriber_spec.py +++ b/trustgraph-base/trustgraph/base/subscriber_spec.py @@ -21,7 +21,7 @@ class SubscriberSpec(Spec): subscriber = Subscriber( backend = processor.pubsub, topic = definition["topics"][self.name], - subscription = flow.id, + subscription = flow.id + "--" + flow.workspace + "--" + flow.name, consumer_name = flow.id, schema = self.schema, metrics = subscriber_metrics, diff --git a/trustgraph-base/trustgraph/base/tool_service.py b/trustgraph-base/trustgraph/base/tool_service.py index 3ff977d1..eeaced6a 100644 --- a/trustgraph-base/trustgraph/base/tool_service.py +++ b/trustgraph-base/trustgraph/base/tool_service.py @@ -64,6 +64,7 @@ class ToolService(FlowProcessor): id = msg.properties()["id"] response = await self.invoke_tool( + flow.workspace, request.name, json.loads(request.parameters) if request.parameters else {}, ) diff --git a/trustgraph-base/trustgraph/base/triples_query_service.py b/trustgraph-base/trustgraph/base/triples_query_service.py index 832ff6f1..5850307c 100644 --- a/trustgraph-base/trustgraph/base/triples_query_service.py +++ b/trustgraph-base/trustgraph/base/triples_query_service.py @@ -58,9 +58,13 @@ class TriplesQueryService(FlowProcessor): logger.debug(f"Handling triples query request {id}...") + workspace = flow.workspace + if request.streaming: # Streaming mode: send batches - async for batch, is_final in self.query_triples_stream(request): + async for batch, is_final in self.query_triples_stream( + workspace, request, + ): r = TriplesQueryResponse( triples=batch, error=None, @@ -70,7 +74,7 @@ class TriplesQueryService(FlowProcessor): logger.debug("Triples query streaming completed") else: # Non-streaming mode: single response - triples = await self.query_triples(request) + triples = await self.query_triples(workspace, request) logger.debug("Sending triples query response...") r = TriplesQueryResponse(triples=triples, error=None) await flow("response").send(r, properties={"id": id}) @@ -92,13 +96,13 @@ class TriplesQueryService(FlowProcessor): await flow("response").send(r, properties={"id": id}) - async def query_triples_stream(self, request): + async def query_triples_stream(self, workspace, request): """ Streaming query - yields (batch, is_final) tuples. Default implementation batches results from query_triples. Override for true streaming from backend. """ - triples = await self.query_triples(request) + triples = await self.query_triples(workspace, request) batch_size = request.batch_size if request.batch_size > 0 else 20 for i in range(0, len(triples), batch_size): diff --git a/trustgraph-base/trustgraph/base/triples_store_service.py b/trustgraph-base/trustgraph/base/triples_store_service.py index abd3aab8..7c44fe29 100644 --- a/trustgraph-base/trustgraph/base/triples_store_service.py +++ b/trustgraph-base/trustgraph/base/triples_store_service.py @@ -45,7 +45,10 @@ class TriplesStoreService(FlowProcessor): request = msg.value() - await self.store_triples(request) + # Workspace is derived from the flow the message arrived on, + # not from fields in the message payload. Topic routing is + # the isolation boundary. + await self.store_triples(flow.workspace, request) except TooManyRequests as e: raise e diff --git a/trustgraph-base/trustgraph/clients/config_client.py b/trustgraph-base/trustgraph/clients/config_client.py index 78b62688..25c1af94 100644 --- a/trustgraph-base/trustgraph/clients/config_client.py +++ b/trustgraph-base/trustgraph/clients/config_client.py @@ -33,6 +33,7 @@ class ConfigClient(BaseClient): subscriber=None, input_queue=None, output_queue=None, + workspace="default", **pubsub_config, ): @@ -51,10 +52,13 @@ class ConfigClient(BaseClient): **pubsub_config, ) + self.workspace = workspace + def get(self, keys, timeout=300): resp = self.call( operation="get", + workspace=self.workspace, keys=[ ConfigKey( type = k["type"], @@ -78,6 +82,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="list", + workspace=self.workspace, type=type, timeout=timeout ) @@ -88,6 +93,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="getvalues", + workspace=self.workspace, type=type, timeout=timeout ) @@ -101,10 +107,31 @@ class ConfigClient(BaseClient): for v in resp.values ] + def getvalues_all_ws(self, type, timeout=300): + """Fetch all values of a given type across all workspaces. + Returns a list of dicts including a 'workspace' field.""" + + resp = self.call( + operation="getvalues-all-ws", + type=type, + timeout=timeout + ) + + return [ + { + "workspace": v.workspace, + "type": v.type, + "key": v.key, + "value": v.value, + } + for v in resp.values + ] + def delete(self, keys, timeout=300): resp = self.call( operation="delete", + workspace=self.workspace, keys=[ ConfigKey( type = k["type"], @@ -121,6 +148,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="put", + workspace=self.workspace, values=[ ConfigValue( type = v["type"], @@ -138,6 +166,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="config", + workspace=self.workspace, timeout=timeout ) diff --git a/trustgraph-base/trustgraph/messaging/translators/collection.py b/trustgraph-base/trustgraph/messaging/translators/collection.py index 2e39e8c2..cd07bc99 100644 --- a/trustgraph-base/trustgraph/messaging/translators/collection.py +++ b/trustgraph-base/trustgraph/messaging/translators/collection.py @@ -9,7 +9,7 @@ class CollectionManagementRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> CollectionManagementRequest: return CollectionManagementRequest( operation=data.get("operation"), - user=data.get("user"), + workspace=data.get("workspace", ""), collection=data.get("collection"), timestamp=data.get("timestamp"), name=data.get("name"), @@ -24,8 +24,8 @@ class CollectionManagementRequestTranslator(MessageTranslator): if obj.operation is not None: result["operation"] = obj.operation - if obj.user is not None: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.collection is not None: result["collection"] = obj.collection if obj.timestamp is not None: @@ -63,7 +63,6 @@ class CollectionManagementResponseTranslator(MessageTranslator): if "collections" in data: for coll_data in data["collections"]: collections.append(CollectionMetadata( - user=coll_data.get("user"), collection=coll_data.get("collection"), name=coll_data.get("name"), description=coll_data.get("description"), @@ -91,7 +90,6 @@ class CollectionManagementResponseTranslator(MessageTranslator): result["collections"] = [] for coll in obj.collections: result["collections"].append({ - "user": coll.user, "collection": coll.collection, "name": coll.name, "description": coll.description, diff --git a/trustgraph-base/trustgraph/messaging/translators/config.py b/trustgraph-base/trustgraph/messaging/translators/config.py index e166362a..223db6c8 100644 --- a/trustgraph-base/trustgraph/messaging/translators/config.py +++ b/trustgraph-base/trustgraph/messaging/translators/config.py @@ -23,13 +23,15 @@ class ConfigRequestTranslator(MessageTranslator): ConfigValue( type=v["type"], key=v["key"], - value=v["value"] + value=v["value"], + workspace=v.get("workspace", ""), ) for v in data["values"] ] return ConfigRequest( operation=data.get("operation"), + workspace=data.get("workspace", ""), keys=keys, type=data.get("type"), values=values @@ -37,10 +39,13 @@ class ConfigRequestTranslator(MessageTranslator): def encode(self, obj: ConfigRequest) -> Dict[str, Any]: result = {} - + if obj.operation is not None: result["operation"] = obj.operation + if obj.workspace is not None: + result["workspace"] = obj.workspace + if obj.type is not None: result["type"] = obj.type @@ -56,13 +61,14 @@ class ConfigRequestTranslator(MessageTranslator): if obj.values is not None: result["values"] = [ { + **({"workspace": v.workspace} if v.workspace else {}), "type": v.type, "key": v.key, - "value": v.value + "value": v.value, } for v in obj.values ] - + return result @@ -81,13 +87,14 @@ class ConfigResponseTranslator(MessageTranslator): if obj.values is not None: result["values"] = [ { + **({"workspace": v.workspace} if v.workspace else {}), "type": v.type, "key": v.key, - "value": v.value + "value": v.value, } for v in obj.values ] - + if obj.directory is not None: result["directory"] = obj.directory diff --git a/trustgraph-base/trustgraph/messaging/translators/document_loading.py b/trustgraph-base/trustgraph/messaging/translators/document_loading.py index df2aa3ba..61917321 100644 --- a/trustgraph-base/trustgraph/messaging/translators/document_loading.py +++ b/trustgraph-base/trustgraph/messaging/translators/document_loading.py @@ -39,7 +39,6 @@ class DocumentTranslator(SendTranslator): metadata=Metadata( id=data.get("id"), root=data.get("root", ""), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), data=base64.b64encode(doc).decode("utf-8") @@ -56,8 +55,6 @@ class DocumentTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection @@ -79,7 +76,6 @@ class TextDocumentTranslator(SendTranslator): metadata=Metadata( id=data.get("id"), root=data.get("root", ""), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), text=text.encode("utf-8") @@ -96,8 +92,6 @@ class TextDocumentTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection @@ -115,7 +109,6 @@ class ChunkTranslator(SendTranslator): metadata=Metadata( id=data.get("id"), root=data.get("root", ""), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), chunk=data["chunk"].encode("utf-8") if isinstance(data["chunk"], str) else data["chunk"] @@ -132,8 +125,6 @@ class ChunkTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection @@ -161,7 +152,6 @@ class DocumentEmbeddingsTranslator(SendTranslator): metadata=Metadata( id=metadata.get("id"), root=metadata.get("root", ""), - user=metadata.get("user", "trustgraph"), collection=metadata.get("collection", "default"), ), chunks=chunks @@ -184,8 +174,6 @@ class DocumentEmbeddingsTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection diff --git a/trustgraph-base/trustgraph/messaging/translators/flow.py b/trustgraph-base/trustgraph/messaging/translators/flow.py index 2047475e..07304c18 100644 --- a/trustgraph-base/trustgraph/messaging/translators/flow.py +++ b/trustgraph-base/trustgraph/messaging/translators/flow.py @@ -9,18 +9,21 @@ class FlowRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> FlowRequest: return FlowRequest( operation=data.get("operation"), + workspace=data.get("workspace", ""), blueprint_name=data.get("blueprint-name"), blueprint_definition=data.get("blueprint-definition"), description=data.get("description"), flow_id=data.get("flow-id"), parameters=data.get("parameters") ) - + def encode(self, obj: FlowRequest) -> Dict[str, Any]: result = {} if obj.operation is not None: result["operation"] = obj.operation + if obj.workspace is not None: + result["workspace"] = obj.workspace if obj.blueprint_name is not None: result["blueprint-name"] = obj.blueprint_name if obj.blueprint_definition is not None: diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index f819dc9c..83cdbbf4 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -21,7 +21,6 @@ class KnowledgeRequestTranslator(MessageTranslator): metadata=Metadata( id=data["triples"]["metadata"]["id"], root=data["triples"]["metadata"].get("root", ""), - user=data["triples"]["metadata"]["user"], collection=data["triples"]["metadata"]["collection"] ), triples=self.subgraph_translator.decode(data["triples"]["triples"]), @@ -33,7 +32,6 @@ class KnowledgeRequestTranslator(MessageTranslator): metadata=Metadata( id=data["graph-embeddings"]["metadata"]["id"], root=data["graph-embeddings"]["metadata"].get("root", ""), - user=data["graph-embeddings"]["metadata"]["user"], collection=data["graph-embeddings"]["metadata"]["collection"] ), entities=[ @@ -47,7 +45,7 @@ class KnowledgeRequestTranslator(MessageTranslator): return KnowledgeRequest( operation=data.get("operation"), - user=data.get("user"), + workspace=data.get("workspace", ""), id=data.get("id"), flow=data.get("flow"), collection=data.get("collection"), @@ -60,8 +58,8 @@ class KnowledgeRequestTranslator(MessageTranslator): if obj.operation: result["operation"] = obj.operation - if obj.user: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.id: result["id"] = obj.id if obj.flow: @@ -74,7 +72,6 @@ class KnowledgeRequestTranslator(MessageTranslator): "metadata": { "id": obj.triples.metadata.id, "root": obj.triples.metadata.root, - "user": obj.triples.metadata.user, "collection": obj.triples.metadata.collection, }, "triples": self.subgraph_translator.encode(obj.triples.triples), @@ -85,7 +82,6 @@ class KnowledgeRequestTranslator(MessageTranslator): "metadata": { "id": obj.graph_embeddings.metadata.id, "root": obj.graph_embeddings.metadata.root, - "user": obj.graph_embeddings.metadata.user, "collection": obj.graph_embeddings.metadata.collection, }, "entities": [ @@ -122,7 +118,6 @@ class KnowledgeResponseTranslator(MessageTranslator): "metadata": { "id": obj.triples.metadata.id, "root": obj.triples.metadata.root, - "user": obj.triples.metadata.user, "collection": obj.triples.metadata.collection, }, "triples": self.subgraph_translator.encode(obj.triples.triples), @@ -136,7 +131,6 @@ class KnowledgeResponseTranslator(MessageTranslator): "metadata": { "id": obj.graph_embeddings.metadata.id, "root": obj.graph_embeddings.metadata.root, - "user": obj.graph_embeddings.metadata.user, "collection": obj.graph_embeddings.metadata.collection, }, "entities": [ diff --git a/trustgraph-base/trustgraph/messaging/translators/library.py b/trustgraph-base/trustgraph/messaging/translators/library.py index 7c77c39c..d528097e 100644 --- a/trustgraph-base/trustgraph/messaging/translators/library.py +++ b/trustgraph-base/trustgraph/messaging/translators/library.py @@ -49,7 +49,7 @@ class LibraryRequestTranslator(MessageTranslator): document_metadata=doc_metadata, processing_metadata=proc_metadata, content=content, - user=data.get("user", ""), + workspace=data.get("workspace", ""), collection=data.get("collection", ""), criteria=criteria, # Chunked upload fields @@ -76,8 +76,8 @@ class LibraryRequestTranslator(MessageTranslator): result["processing-metadata"] = self.proc_metadata_translator.encode(obj.processing_metadata) if obj.content: result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content - if obj.user: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.collection: result["collection"] = obj.collection if obj.criteria is not None: diff --git a/trustgraph-base/trustgraph/messaging/translators/metadata.py b/trustgraph-base/trustgraph/messaging/translators/metadata.py index 3e141c19..9da5d5c0 100644 --- a/trustgraph-base/trustgraph/messaging/translators/metadata.py +++ b/trustgraph-base/trustgraph/messaging/translators/metadata.py @@ -19,7 +19,7 @@ class DocumentMetadataTranslator(Translator): title=data.get("title"), comments=data.get("comments"), metadata=self.subgraph_translator.decode(metadata) if metadata is not None else [], - user=data.get("user"), + workspace=data.get("workspace"), tags=data.get("tags"), parent_id=data.get("parent-id", ""), document_type=data.get("document-type", "source"), @@ -40,8 +40,8 @@ class DocumentMetadataTranslator(Translator): result["comments"] = obj.comments if obj.metadata is not None: result["metadata"] = self.subgraph_translator.encode(obj.metadata) - if obj.user: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.tags is not None: result["tags"] = obj.tags if obj.parent_id: @@ -61,7 +61,7 @@ class ProcessingMetadataTranslator(Translator): document_id=data.get("document-id"), time=data.get("time"), flow=data.get("flow"), - user=data.get("user"), + workspace=data.get("workspace"), collection=data.get("collection"), tags=data.get("tags") ) @@ -77,8 +77,8 @@ class ProcessingMetadataTranslator(Translator): result["time"] = obj.time if obj.flow: result["flow"] = obj.flow - if obj.user: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.collection: result["collection"] = obj.collection if obj.tags is not None: diff --git a/trustgraph-base/trustgraph/schema/core/metadata.py b/trustgraph-base/trustgraph/schema/core/metadata.py index a37a8d62..a307db4f 100644 --- a/trustgraph-base/trustgraph/schema/core/metadata.py +++ b/trustgraph-base/trustgraph/schema/core/metadata.py @@ -8,6 +8,7 @@ class Metadata: # Root document identifier (set by librarian, preserved through pipeline) root: str = "" - # Collection management - user: str = "" + # Collection the message belongs to. Workspace is NOT carried on the + # message — consumers derive it from flow.workspace (the flow the + # message arrived on), which is the trusted isolation boundary. collection: str = "" diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py index 0c4a9f7c..37969566 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py +++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py @@ -17,7 +17,7 @@ from .embeddings import GraphEmbeddings # <- (error) # list-kg-cores -# -> (user) +# -> (workspace) # <- () # <- (error) @@ -27,8 +27,8 @@ class KnowledgeRequest: # load-kg-core, unload-kg-core operation: str = "" - # list-kg-cores, delete-kg-core, put-kg-core - user: str = "" + # Workspace the cores belong to. Partition / isolation boundary. + workspace: str = "" # get-kg-core, list-kg-cores, delete-kg-core, put-kg-core, # load-kg-core, unload-kg-core diff --git a/trustgraph-base/trustgraph/schema/services/collection.py b/trustgraph-base/trustgraph/schema/services/collection.py index f4b5fc6e..13dd0607 100644 --- a/trustgraph-base/trustgraph/schema/services/collection.py +++ b/trustgraph-base/trustgraph/schema/services/collection.py @@ -13,7 +13,6 @@ from ..core.topic import queue @dataclass class CollectionMetadata: """Collection metadata record""" - user: str = "" collection: str = "" name: str = "" description: str = "" @@ -23,11 +22,17 @@ class CollectionMetadata: @dataclass class CollectionManagementRequest: - """Request for collection management operations""" + """Request for collection management operations. + + Collection-management is a global (non-flow-scoped) service, so the + workspace has to travel on the wire — it's the isolation boundary + for which workspace's collections the request operates on. + """ operation: str = "" # e.g., "delete-collection" - # For 'list-collections' - user: str = "" + # Workspace the collection belongs to. + workspace: str = "" + collection: str = "" timestamp: str = "" # ISO timestamp name: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/config.py b/trustgraph-base/trustgraph/schema/services/config.py index c08e96d7..3bcbc72c 100644 --- a/trustgraph-base/trustgraph/schema/services/config.py +++ b/trustgraph-base/trustgraph/schema/services/config.py @@ -7,12 +7,19 @@ from ..core.primitives import Error ############################################################################ # Config service: -# get(keys) -> (version, values) -# list(type) -> (version, values) -# getvalues(type) -> (version, values) -# put(values) -> () -# delete(keys) -> () -# config() -> (version, config) +# get(workspace, keys) -> (version, values) +# list(workspace, type) -> (version, directory) +# getvalues(workspace, type) -> (version, values) +# getvalues-all-ws(type) -> (version, values with workspace field) +# put(workspace, values) -> () +# delete(workspace, keys) -> () +# config(workspace) -> (version, config) +# +# Most operations are scoped to a workspace. The workspace field on the +# request identifies which workspace's config to read or modify. +# getvalues-all-ws returns values across all workspaces for a single +# type — used by shared processors to load type-scoped config at startup. + @dataclass class ConfigKey: type: str = "" @@ -23,16 +30,24 @@ class ConfigValue: type: str = "" key: str = "" value: str = "" + # Populated by getvalues-all-ws responses so callers can identify + # which workspace each value belongs to. Empty otherwise. + workspace: str = "" -# Prompt services, abstract the prompt generation @dataclass class ConfigRequest: - operation: str = "" # get, list, getvalues, delete, put, config + # Operations: get, list, getvalues, getvalues-all-ws, delete, put, + # config + operation: str = "" + + # Workspace scope — required on all operations except + # getvalues-all-ws which spans all workspaces. + workspace: str = "" # get, delete keys: list[ConfigKey] = field(default_factory=list) - # list, getvalues + # list, getvalues, getvalues-all-ws type: str = "" # put @@ -58,7 +73,12 @@ class ConfigResponse: @dataclass class ConfigPush: version: int = 0 - types: list[str] = field(default_factory=list) + + # Dict of config type -> list of affected workspaces. + # Handlers look up their registered type and get the list of + # workspaces that need refreshing. + # e.g. {"prompt": ["workspace-a", "workspace-b"], "schema": ["workspace-a"]} + changes: dict[str, list[str]] = field(default_factory=dict) config_request_queue = queue('config', cls='request') config_response_queue = queue('config', cls='response') diff --git a/trustgraph-base/trustgraph/schema/services/flow.py b/trustgraph-base/trustgraph/schema/services/flow.py index 0d497dd7..586c160d 100644 --- a/trustgraph-base/trustgraph/schema/services/flow.py +++ b/trustgraph-base/trustgraph/schema/services/flow.py @@ -17,12 +17,14 @@ from ..core.primitives import Error # start_flow(flowid, blueprintname) -> () # stop_flow(flowid) -> () -# Prompt services, abstract the prompt generation @dataclass class FlowRequest: operation: str = "" # list-blueprints, get-blueprint, put-blueprint, delete-blueprint # list-flows, get-flow, start-flow, stop-flow + # Workspace scope — all operations act within this workspace + workspace: str = "" + # get_blueprint, put_blueprint, delete_blueprint, start_flow blueprint_name: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/library.py b/trustgraph-base/trustgraph/schema/services/library.py index f5d4592c..961b47dc 100644 --- a/trustgraph-base/trustgraph/schema/services/library.py +++ b/trustgraph-base/trustgraph/schema/services/library.py @@ -43,12 +43,12 @@ from ..core.metadata import Metadata # <- (error) # list-documents -# -> (user, collection?) +# -> (workspace, collection?) # <- (document_metadata[]) # <- (error) # list-processing -# -> (user, collection?) +# -> (workspace, collection?) # <- (processing_metadata[]) # <- (error) @@ -78,7 +78,7 @@ from ..core.metadata import Metadata # <- (error) # list-uploads -# -> (user) +# -> (workspace) # <- (uploads[]) # <- (error) @@ -90,7 +90,7 @@ class DocumentMetadata: title: str = "" comments: str = "" metadata: list[Triple] = field(default_factory=list) - user: str = "" + workspace: str = "" tags: list[str] = field(default_factory=list) # Child document support parent_id: str = "" # Empty for top-level docs, set for children @@ -107,7 +107,7 @@ class ProcessingMetadata: document_id: str = "" time: int = 0 flow: str = "" - user: str = "" + workspace: str = "" collection: str = "" tags: list[str] = field(default_factory=list) @@ -162,8 +162,8 @@ class LibrarianRequest: # add-document, upload-chunk content: bytes = b"" - # list-documents, list-processing, list-uploads - user: str = "" + # Workspace scopes every library operation. + workspace: str = "" # list-documents?, list-processing? collection: str = "" diff --git a/trustgraph-bedrock/pyproject.toml b/trustgraph-bedrock/pyproject.toml index 2d65461b..f0c8d571 100644 --- a/trustgraph-bedrock/pyproject.toml +++ b/trustgraph-bedrock/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 0151fef4..a5738449 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "requests", "pulsar-client", "aiohttp", @@ -95,6 +95,8 @@ tg-list-config-items = "trustgraph.cli.list_config_items:main" tg-get-config-item = "trustgraph.cli.get_config_item:main" tg-put-config-item = "trustgraph.cli.put_config_item:main" tg-delete-config-item = "trustgraph.cli.delete_config_item:main" +tg-export-workspace-config = "trustgraph.cli.export_workspace_config:main" +tg-import-workspace-config = "trustgraph.cli.import_workspace_config:main" tg-list-collections = "trustgraph.cli.list_collections:main" tg-set-collection = "trustgraph.cli.set_collection:main" tg-delete-collection = "trustgraph.cli.delete_collection:main" diff --git a/trustgraph-cli/trustgraph/cli/add_library_document.py b/trustgraph-cli/trustgraph/cli/add_library_document.py index 3273e63d..8d08d11a 100644 --- a/trustgraph-cli/trustgraph/cli/add_library_document.py +++ b/trustgraph-cli/trustgraph/cli/add_library_document.py @@ -15,17 +15,17 @@ from trustgraph.knowledge import Organization, PublicationEvent from trustgraph.knowledge import DigitalDocument default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") class Loader: def __init__( - self, id, url, user, metadata, title, comments, kind, tags - ): + self, id, url, metadata, title, comments, kind, tags + , token=None, workspace="default"): - self.api = Api(url).library() + self.api = Api(url, token=token, workspace=workspace).library() - self.user = user self.metadata = metadata self.title = title self.comments = comments @@ -55,13 +55,13 @@ class Loader: else: id = hash(data) id = to_uri(PREF_DOC, id) - + self.metadata.id = id self.api.add_document( - document=data, id=id, metadata=self.metadata, - user=self.user, kind=self.kind, title=self.title, + document=data, id=id, metadata=self.metadata, + kind=self.kind, title=self.title, comments=self.comments, tags=self.tags ) @@ -83,11 +83,16 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -186,12 +191,13 @@ def main(): p = Loader( id=args.identifier, url=args.url, - user=args.user, metadata=document, title=args.name, comments=args.description, kind=args.kind, tags=args.tags, + token=args.token, + workspace=args.workspace, ) p.load(args.files) diff --git a/trustgraph-cli/trustgraph/cli/delete_collection.py b/trustgraph-cli/trustgraph/cli/delete_collection.py index 3e19ac09..aedd801a 100644 --- a/trustgraph-cli/trustgraph/cli/delete_collection.py +++ b/trustgraph-cli/trustgraph/cli/delete_collection.py @@ -7,9 +7,11 @@ import os from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_collection(url, user, collection, confirm): + +def delete_collection(url, collection, confirm, token=None, workspace="default"): if not confirm: response = input(f"Are you sure you want to delete collection '{collection}' and all its data? (y/N): ") @@ -17,9 +19,9 @@ def delete_collection(url, user, collection, confirm): print("Operation cancelled.") return - api = Api(url).collection() + api = Api(url, token=token, workspace=workspace).collection() - api.delete_collection(user=user, collection=collection) + api.delete_collection(collection=collection) print(f"Collection '{collection}' deleted successfully.") @@ -41,27 +43,34 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-y', '--yes', action='store_true', help='Skip confirmation prompt' ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: delete_collection( url = args.api_url, - user = args.user, collection = args.collection, - confirm = args.yes + confirm = args.yes, + token = args.token, + workspace = args.workspace, ) except Exception as e: @@ -69,4 +78,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/delete_config_item.py b/trustgraph-cli/trustgraph/cli/delete_config_item.py index cf4cba93..801c2a99 100644 --- a/trustgraph-cli/trustgraph/cli/delete_config_item.py +++ b/trustgraph-cli/trustgraph/cli/delete_config_item.py @@ -9,10 +9,11 @@ from trustgraph.api.types import ConfigKey default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_config_item(url, config_type, key, token=None): +def delete_config_item(url, config_type, key, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() config_key = ConfigKey(type=config_type, key=key) api.delete([config_key]) @@ -50,6 +51,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -59,6 +66,8 @@ def main(): config_type=args.type, key=args.key, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py b/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py index 9ff8aeba..62140f0e 100644 --- a/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py +++ b/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py @@ -9,10 +9,13 @@ from trustgraph.api import Api import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_flow_blueprint(url, blueprint_name): +def delete_flow_blueprint(url, blueprint_name, token=None, + workspace="default"): - api = Api(url).flow() + api = Api(url, token=token, workspace=workspace).flow() blueprint_names = api.delete_blueprint(blueprint_name) @@ -29,6 +32,18 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', help=f'Flow blueprint name', @@ -41,6 +56,8 @@ def main(): delete_flow_blueprint( url=args.api_url, blueprint_name=args.blueprint_name, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_kg_core.py b/trustgraph-cli/trustgraph/cli/delete_kg_core.py index 81f95e45..0e0753e0 100644 --- a/trustgraph-cli/trustgraph/cli/delete_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/delete_kg_core.py @@ -1,20 +1,20 @@ """ -Deletes a flow class +Deletes a knowledge core """ import argparse import os -import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_kg_core(url, user, id): +def delete_kg_core(url, id, token=None, workspace="default"): - api = Api(url).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() - class_names = api.delete_kg_core(user = user, id = id) + api.delete_kg_core(id=id) def main(): @@ -29,26 +29,33 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', - ) - parser.add_argument( '--id', '--identifier', required=True, help=f'Knowledge core ID', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: delete_kg_core( url=args.api_url, - user=args.user, id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py b/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py index a3ae7e77..eed9ed21 100644 --- a/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py +++ b/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py @@ -10,12 +10,16 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def delete_mcp_tool( url : str, id : str, + token=None, + workspace="default", ): - api = Api(url).config() + api = Api(url, token=token, workspace=workspace).config() # Check if the tool exists first try: @@ -73,6 +77,18 @@ def main(): help='MCP tool ID to delete', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -81,8 +97,10 @@ def main(): raise RuntimeError("Must specify --id for MCP tool to delete") delete_mcp_tool( - url=args.api_url, - id=args.id + url=args.api_url, + id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_tool.py b/trustgraph-cli/trustgraph/cli/delete_tool.py index 961c9aa8..50f43fdd 100644 --- a/trustgraph-cli/trustgraph/cli/delete_tool.py +++ b/trustgraph-cli/trustgraph/cli/delete_tool.py @@ -12,12 +12,16 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def delete_tool( url : str, id : str, + token=None, + workspace="default", ): - api = Api(url).config() + api = Api(url, token=token, workspace=workspace).config() # Check if the tool configuration exists try: @@ -78,6 +82,18 @@ def main(): help='Tool ID to delete', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -86,8 +102,10 @@ def main(): raise RuntimeError("Must specify --id for tool to delete") delete_tool( - url=args.api_url, - id=args.id + url=args.api_url, + id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/export_workspace_config.py b/trustgraph-cli/trustgraph/cli/export_workspace_config.py new file mode 100644 index 00000000..feef97de --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/export_workspace_config.py @@ -0,0 +1,114 @@ +""" +Exports a curated subset of a workspace's configuration to a JSON file +for later reload into another workspace (useful for cloning test setups). + +The subset covers the config types that define workspace behaviour: +mcp-tool, tool, flow-blueprint, token-cost, agent-pattern, +agent-task-type, parameter-type, interface-description, prompt. +""" + +import argparse +import os +import json +import sys +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +EXPORT_TYPES = [ + "mcp-tool", + "tool", + "flow-blueprint", + "token-cost", + "agent-pattern", + "agent-task-type", + "parameter-type", + "interface-description", + "prompt", +] + + +def export_workspace_config(url, workspace, output, token=None): + + api = Api(url, token=token, workspace=workspace).config() + + config, version = api.all() + + subset = {} + for t in EXPORT_TYPES: + if t in config: + subset[t] = config[t] + + payload = { + "source_workspace": workspace, + "source_version": version, + "config": subset, + } + + if output == "-": + json.dump(payload, sys.stdout, indent=2) + sys.stdout.write("\n") + else: + with open(output, "w") as f: + json.dump(payload, f, indent=2) + + total = sum(len(v) for v in subset.values()) + print( + f"Exported {total} items across {len(subset)} types " + f"from workspace '{workspace}' (version {version}).", + file=sys.stderr, + ) + + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-export-workspace-config', + description=__doc__, + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Source workspace (default: {default_workspace})', + ) + + parser.add_argument( + '-o', '--output', + required=True, + help='Output JSON file path (use "-" for stdout)', + ) + + args = parser.parse_args() + + try: + + export_workspace_config( + url=args.api_url, + workspace=args.workspace, + output=args.output, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/get_config_item.py b/trustgraph-cli/trustgraph/cli/get_config_item.py index c2421e94..028cc064 100644 --- a/trustgraph-cli/trustgraph/cli/get_config_item.py +++ b/trustgraph-cli/trustgraph/cli/get_config_item.py @@ -10,10 +10,12 @@ from trustgraph.api.types import ConfigKey default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def get_config_item(url, config_type, key, format_type, token=None): +def get_config_item(url, config_type, key, format_type, token=None, + workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() config_key = ConfigKey(type=config_type, key=key) values = api.get([config_key]) @@ -66,6 +68,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -76,6 +84,7 @@ def main(): key=args.key, format_type=args.format, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/get_document_content.py b/trustgraph-cli/trustgraph/cli/get_document_content.py index 3d70f37d..62fa7ca2 100644 --- a/trustgraph-cli/trustgraph/cli/get_document_content.py +++ b/trustgraph-cli/trustgraph/cli/get_document_content.py @@ -9,21 +9,19 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def get_content(url, user, document_id, output_file, token=None): +def get_content(url, document_id, output_file, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - content = api.get_document_content(user=user, id=document_id) + content = api.get_document_content(id=document_id) if output_file: with open(output_file, 'wb') as f: f.write(content) print(f"Written {len(content)} bytes to {output_file}") else: - # Write to stdout - # Try to decode as text, fall back to binary info try: text = content.decode('utf-8') print(text) @@ -51,9 +49,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -73,10 +71,10 @@ def main(): get_content( url=args.api_url, - user=args.user, document_id=args.document_id, output_file=args.output, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py b/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py index 817b8f47..56d43a7c 100644 --- a/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py +++ b/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py @@ -9,10 +9,12 @@ from trustgraph.api import Api import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def get_flow_blueprint(url, blueprint_name): +def get_flow_blueprint(url, blueprint_name, token=None, workspace="default"): - api = Api(url).flow() + api = Api(url, token=token, workspace=workspace).flow() cls = api.get_blueprint(blueprint_name) @@ -31,6 +33,18 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', required=True, @@ -44,6 +58,8 @@ def main(): get_flow_blueprint( url=args.api_url, blueprint_name=args.blueprint_name, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/get_kg_core.py b/trustgraph-cli/trustgraph/cli/get_kg_core.py index b75f7155..8bee4115 100644 --- a/trustgraph-cli/trustgraph/cli/get_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/get_kg_core.py @@ -5,7 +5,6 @@ to a local file in msgpack format. import argparse import os -import textwrap import uuid import asyncio import json @@ -13,17 +12,16 @@ from websockets.asyncio.client import connect import msgpack default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') -default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def write_triple(f, data): msg = ( "t", { "m": { - "i": data["metadata"]["id"], + "i": data["metadata"]["id"], "m": data["metadata"]["metadata"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "t": data["triples"], @@ -36,9 +34,8 @@ def write_ge(f, data): "ge", { "m": { - "i": data["metadata"]["id"], + "i": data["metadata"]["id"], "m": data["metadata"]["metadata"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "e": [ @@ -52,7 +49,7 @@ def write_ge(f, data): ) f.write(msgpack.packb(msg, use_bin_type=True)) -async def fetch(url, user, id, output, token=None): +async def fetch(url, workspace, id, output, token=None): if not url.endswith("/"): url += "/" @@ -68,10 +65,11 @@ async def fetch(url, user, id, output, token=None): req = json.dumps({ "id": mid, + "workspace": workspace, "service": "knowledge", "request": { "operation": "get-kg-core", - "user": user, + "workspace": workspace, "id": id, } }) @@ -124,10 +122,11 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -154,11 +153,11 @@ def main(): asyncio.run( fetch( - url = args.url, - user = args.user, - id = args.id, - output = args.output, - token = args.token, + url=args.url, + workspace=args.workspace, + id=args.id, + output=args.output, + token=args.token, ) ) @@ -167,4 +166,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/graph_to_turtle.py b/trustgraph-cli/trustgraph/cli/graph_to_turtle.py index 840f8574..4d4a94b3 100644 --- a/trustgraph-cli/trustgraph/cli/graph_to_turtle.py +++ b/trustgraph-cli/trustgraph/cli/graph_to_turtle.py @@ -13,9 +13,9 @@ import os from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_collection = 'default' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def term_to_rdflib(term): @@ -58,9 +58,10 @@ def term_to_rdflib(term): return rdflib.term.Literal(str(term)) -def show_graph(url, flow_id, user, collection, limit, batch_size, token=None): +def show_graph(url, flow_id, collection, limit, batch_size, + token=None, workspace="default"): - socket = Api(url, token=token).socket() + socket = Api(url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) g = rdflib.Graph() @@ -68,7 +69,7 @@ def show_graph(url, flow_id, user, collection, limit, batch_size, token=None): try: for batch in flow.triples_query_stream( s=None, p=None, o=None, - user=user, collection=collection, + collection=collection, limit=limit, batch_size=batch_size, ): @@ -108,12 +109,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -126,6 +121,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-l', '--limit', type=int, @@ -147,11 +148,11 @@ def main(): show_graph( url = args.api_url, flow_id = args.flow_id, - user = args.user, collection = args.collection, limit = args.limit, batch_size = args.batch_size, token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/import_workspace_config.py b/trustgraph-cli/trustgraph/cli/import_workspace_config.py new file mode 100644 index 00000000..3fe3be97 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/import_workspace_config.py @@ -0,0 +1,143 @@ +""" +Imports a workspace-config dump produced by tg-export-workspace-config +into a target workspace. Writes mcp-tool, tool, flow-blueprint, +token-cost, agent-pattern, agent-task-type, parameter-type, +interface-description and prompt items verbatim. + +Existing items with the same (type, key) are overwritten. +""" + +import argparse +import os +import json +import sys +from trustgraph.api import Api +from trustgraph.api.types import ConfigValue + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +IMPORT_TYPES = { + "mcp-tool", + "tool", + "flow-blueprint", + "token-cost", + "agent-pattern", + "agent-task-type", + "parameter-type", + "interface-description", + "prompt", +} + + +def import_workspace_config(url, workspace, input_path, token=None, + dry_run=False): + + if input_path == "-": + payload = json.load(sys.stdin) + else: + with open(input_path, "r") as f: + payload = json.load(f) + + # Accept both the wrapped export format and a bare {type: {key: value}} + # dict, so hand-written files are also loadable. + if isinstance(payload, dict) and "config" in payload \ + and isinstance(payload["config"], dict): + config = payload["config"] + source = payload.get("source_workspace") + else: + config = payload + source = None + + skipped_types = set(config.keys()) - IMPORT_TYPES + if skipped_types: + print( + f"Ignoring unsupported types: {sorted(skipped_types)}", + file=sys.stderr, + ) + + values = [] + for t in IMPORT_TYPES: + items = config.get(t, {}) + for key, value in items.items(): + values.append(ConfigValue(type=t, key=key, value=value)) + + if not values: + print("Nothing to import.", file=sys.stderr) + return + + if dry_run: + print( + f"[dry-run] would import {len(values)} items into " + f"workspace '{workspace}'" + + (f" (from '{source}')" if source else "") + ) + return + + api = Api(url, token=token, workspace=workspace).config() + api.put(values) + + print( + f"Imported {len(values)} items into workspace '{workspace}'" + + (f" (from '{source}')." if source else "."), + ) + + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-import-workspace-config', + description=__doc__, + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Target workspace (default: {default_workspace})', + ) + + parser.add_argument( + '-i', '--input', + required=True, + help='Input JSON file path (use "-" for stdin)', + ) + + parser.add_argument( + '--dry-run', + action='store_true', + help='Parse and validate the input without writing anything', + ) + + args = parser.parse_args() + + try: + + import_workspace_config( + url=args.api_url, + workspace=args.workspace, + input_path=args.input, + token=args.token, + dry_run=args.dry_run, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/init_trustgraph.py b/trustgraph-cli/trustgraph/cli/init_trustgraph.py index 18c240ef..d984f925 100644 --- a/trustgraph-cli/trustgraph/cli/init_trustgraph.py +++ b/trustgraph-cli/trustgraph/cli/init_trustgraph.py @@ -69,10 +69,11 @@ def ensure_namespace(url, tenant, namespace, config): print(f"Namespace {tenant}/{namespace} created.", flush=True) -def ensure_config(config, **pubsub_config): +def ensure_config(config, workspace="default", **pubsub_config): cli = ConfigClient( subscriber=subscriber, + workspace=workspace, **pubsub_config, ) @@ -147,7 +148,8 @@ def init_pulsar(pulsar_admin_url, tenant): }) -def push_config(config_json, config_file, **pubsub_config): +def push_config(config_json, config_file, workspace="default", + **pubsub_config): """Push initial config if provided.""" if config_json is not None: @@ -160,7 +162,7 @@ def push_config(config_json, config_file, **pubsub_config): print("Exception:", e, flush=True) raise e - ensure_config(dec, **pubsub_config) + ensure_config(dec, workspace=workspace, **pubsub_config) elif config_file is not None: @@ -172,7 +174,7 @@ def push_config(config_json, config_file, **pubsub_config): print("Exception:", e, flush=True) raise e - ensure_config(dec, **pubsub_config) + ensure_config(dec, workspace=workspace, **pubsub_config) else: print("No config to update.", flush=True) @@ -207,6 +209,12 @@ def main(): help=f'Tenant (default: tg)', ) + parser.add_argument( + '-w', '--workspace', + default="default", + help=f'Workspace (default: default)', + ) + add_pubsub_args(parser) args = parser.parse_args() @@ -216,7 +224,10 @@ def main(): # Extract pubsub config from args pubsub_config = { k: v for k, v in vars(args).items() - if k not in ('pulsar_admin_url', 'config', 'config_file', 'tenant') + if k not in ( + 'pulsar_admin_url', 'config', 'config_file', 'tenant', + 'workspace', + ) } while True: @@ -241,6 +252,7 @@ def main(): # Push config (works with any backend) push_config( args.config, args.config_file, + workspace=args.workspace, **pubsub_config, ) diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index b379c2df..7490f868 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -26,7 +26,7 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' class Outputter: @@ -119,7 +119,7 @@ def question_explainable( state=None, group=None, verbose=False, token=None, debug=False ): """Execute agent with explainability - shows provenance events inline.""" - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) @@ -296,7 +296,7 @@ def question( print() # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -418,6 +418,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -430,12 +436,6 @@ def main(): help=f'Question to answer', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -502,7 +502,7 @@ def main(): url = args.url, flow_id = args.flow_id, question = args.question, - user = args.user, + user = "", collection = args.collection, plan = args.plan, state = args.state, diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py index 43bcc985..ed851dff 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, query_text, user, collection, limit, token=None): +def query(url, flow_id, query_text, collection, limit, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -21,7 +22,6 @@ def query(url, flow_id, query_text, user, collection, limit, token=None): # Call document embeddings query service result = flow.document_embeddings_query( text=query_text, - user=user, collection=collection, limit=limit ) @@ -59,15 +59,15 @@ def main(): ) parser.add_argument( - '-f', '--flow-id', - default="default", - help=f'Flow ID (default: default)' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='User/keyspace (default: trustgraph)', + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' ) parser.add_argument( @@ -97,10 +97,10 @@ def main(): url=args.url, flow_id=args.flow_id, query_text=args.query[0], - user=args.user, collection=args.collection, limit=args.limit, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index d566f51d..9a2a1118 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -18,16 +18,17 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' default_doc_limit = 10 def question_explainable( - url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False + url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False, + workspace="default", ): """Execute document RAG with explainability - shows provenance events inline.""" - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) @@ -100,7 +101,7 @@ def question_explainable( def question( url, flow_id, question_text, user, collection, doc_limit, streaming=True, token=None, explainable=False, debug=False, - show_usage=False + show_usage=False, workspace="default", ): # Explainable mode uses the API to capture and process provenance events if explainable: @@ -112,12 +113,13 @@ def question( collection=collection, doc_limit=doc_limit, token=token, - debug=debug + debug=debug, + workspace=workspace, ) return # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) if streaming: # Use socket client for streaming @@ -189,6 +191,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -201,12 +209,6 @@ def main(): help=f'Question to answer', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -252,7 +254,7 @@ def main(): url=args.url, flow_id=args.flow_id, question_text=args.question, - user=args.user, + user="", collection=args.collection, doc_limit=args.doc_limit, streaming=not args.no_streaming, @@ -260,6 +262,7 @@ def main(): explainable=args.explainable, debug=args.debug, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py index 699a85cf..62eaa039 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, texts, token=None): +def query(url, flow_id, texts, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -51,6 +52,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -72,6 +79,8 @@ def main(): flow_id=args.flow_id, texts=args.texts, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py index 5b0f4c67..c7237c06 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, query_text, user, collection, limit, token=None): +def query(url, flow_id, query_text, collection, limit, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -21,7 +22,6 @@ def query(url, flow_id, query_text, user, collection, limit, token=None): # Call graph embeddings query service result = flow.graph_embeddings_query( text=query_text, - user=user, collection=collection, limit=limit ) @@ -69,15 +69,15 @@ def main(): ) parser.add_argument( - '-f', '--flow-id', - default="default", - help=f'Flow ID (default: default)' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='User/keyspace (default: trustgraph)', + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' ) parser.add_argument( @@ -107,10 +107,10 @@ def main(): url=args.url, flow_id=args.flow_id, query_text=args.query[0], - user=args.user, collection=args.collection, limit=args.limit, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index c9efe54d..5fed9496 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -22,7 +22,7 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' default_entity_limit = 50 default_triple_limit = 30 @@ -641,10 +641,10 @@ async def _question_explainable( def _question_explainable_api( url, flow_id, question_text, user, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=30, - edge_limit=25, token=None, debug=False + edge_limit=25, token=None, debug=False, workspace="default", ): """Execute graph RAG with explainability using the new API classes.""" - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) @@ -753,7 +753,8 @@ def question( url, flow_id, question, user, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=50, edge_limit=25, streaming=True, token=None, - explainable=False, debug=False, show_usage=False + explainable=False, debug=False, show_usage=False, + workspace="default", ): # Explainable mode uses the API to capture and process provenance events @@ -771,12 +772,13 @@ def question( edge_score_limit=edge_score_limit, edge_limit=edge_limit, token=token, - debug=debug + debug=debug, + workspace=workspace, ) return # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) if streaming: # Use socket client for streaming @@ -857,6 +859,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -869,12 +877,6 @@ def main(): help=f'Question to answer', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -955,7 +957,7 @@ def main(): url=args.url, flow_id=args.flow_id, question=args.question, - user=args.user, + user="", collection=args.collection, entity_limit=args.entity_limit, triple_limit=args.triple_limit, @@ -968,6 +970,7 @@ def main(): explainable=args.explainable, debug=args.debug, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_llm.py b/trustgraph-cli/trustgraph/cli/invoke_llm.py index 3bf521f6..2006e9e8 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_llm.py +++ b/trustgraph-cli/trustgraph/cli/invoke_llm.py @@ -9,12 +9,13 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def query(url, flow_id, system, prompt, streaming=True, token=None, - show_usage=False): + show_usage=False, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -74,6 +75,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( 'system', nargs=1, @@ -116,6 +123,7 @@ def main(): streaming=not args.no_streaming, token=args.token, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py b/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py index c5700c5c..32c20768 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py +++ b/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py @@ -11,10 +11,12 @@ import json from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, name, parameters): +def query(url, flow_id, name, parameters, token=None, workspace="default"): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) resp = api.mcp_tool(name=name, parameters=parameters) @@ -36,6 +38,18 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -68,6 +82,8 @@ def main(): flow_id = args.flow_id, name = args.name, parameters = parameters, + token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py b/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py index 8b01187c..332531db 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py @@ -10,9 +10,11 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -def nlp_query(url, flow_id, question, max_results, output_format='json'): +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") +def nlp_query(url, flow_id, question, max_results, output_format='json', token=None, workspace="default"): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) resp = api.nlp_query( question=question, @@ -63,6 +65,17 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) parser.add_argument( '-f', '--flow-id', @@ -100,6 +113,11 @@ def main(): question=args.question, max_results=args.max_results, output_format=args.format, + + token = args.token, + + workspace = args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_prompt.py b/trustgraph-cli/trustgraph/cli/invoke_prompt.py index 86f7a024..ed47df90 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_prompt.py +++ b/trustgraph-cli/trustgraph/cli/invoke_prompt.py @@ -14,12 +14,13 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def query(url, flow_id, template_id, variables, streaming=True, token=None, - show_usage=False): + show_usage=False, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -80,6 +81,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -135,6 +142,7 @@ specified multiple times''', streaming=not args.no_streaming, token=args.token, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py index 7393b4c3..8244ae99 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, query_text, schema_name, user, collection, index_name, limit, token=None): +def query(url, flow_id, query_text, schema_name, collection, index_name, limit, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -22,7 +23,6 @@ def query(url, flow_id, query_text, schema_name, user, collection, index_name, l result = flow.row_embeddings_query( text=query_text, schema_name=schema_name, - user=user, collection=collection, index_name=index_name, limit=limit @@ -60,15 +60,15 @@ def main(): ) parser.add_argument( - '-f', '--flow-id', - default="default", - help=f'Flow ID (default: default)' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='User/keyspace (default: trustgraph)', + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' ) parser.add_argument( @@ -111,11 +111,11 @@ def main(): flow_id=args.flow_id, query_text=args.query[0], schema_name=args.schema_name, - user=args.user, collection=args.collection, index_name=args.index_name, limit=args.limit, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_rows_query.py b/trustgraph-cli/trustgraph/cli/invoke_rows_query.py index 962f353c..46fba4d7 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_rows_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_rows_query.py @@ -12,10 +12,11 @@ from trustgraph.api import Api from tabulate import tabulate default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' -def format_output(data, output_format): +def format_output(data, output_format, token=None, workspace="default"): """Format GraphQL response data in the specified format""" if not data: return "No data returned" @@ -82,10 +83,10 @@ def format_table_data(rows, table_name, output_format): return json.dumps({table_name: rows}, indent=2) def rows_query( - url, flow_id, query, user, collection, variables, operation_name, output_format='table' + url, flow_id, query, collection, variables, operation_name, output_format='table', token=None, workspace="default" ): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) # Parse variables if provided as JSON string parsed_variables = {} @@ -98,7 +99,6 @@ def rows_query( resp = api.rows_query( query=query, - user=user, collection=collection, variables=parsed_variables if parsed_variables else None, operation_name=operation_name @@ -135,6 +135,17 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) parser.add_argument( '-f', '--flow-id', @@ -148,12 +159,6 @@ def main(): help='GraphQL query to execute', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -185,11 +190,13 @@ def main(): url=args.url, flow_id=args.flow_id, query=args.query, - user=args.user, collection=args.collection, variables=args.variables, operation_name=args.operation_name, output_format=args.format, + token=args.token, + workspace=args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py index 7b1ae9a6..26e03929 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py @@ -9,7 +9,8 @@ import sys from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' @@ -44,10 +45,10 @@ def _term_str(val): return str(val) -def sparql_query(url, token, flow_id, query, user, collection, limit, - batch_size, output_format): +def sparql_query(url, token, flow_id, query, collection, limit, + batch_size, output_format, workspace="default"): - socket = Api(url=url, token=token).socket() + socket = Api(url=url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) variables = None @@ -57,7 +58,6 @@ def sparql_query(url, token, flow_id, query, user, collection, limit, for response in flow.sparql_query_stream( query=query, - user=user, collection=collection, limit=limit, batch_size=batch_size, @@ -154,8 +154,14 @@ def main(): parser.add_argument( '-t', '--token', - default=os.getenv("TRUSTGRAPH_TOKEN"), - help='API bearer token (default: TRUSTGRAPH_TOKEN env var)', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -174,12 +180,6 @@ def main(): help='Read SPARQL query from file (use - for stdin)', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -228,11 +228,11 @@ def main(): token=args.token, flow_id=args.flow_id, query=query, - user=args.user, collection=args.collection, limit=args.limit, batch_size=args.batch_size, output_format=args.format, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_structured_query.py b/trustgraph-cli/trustgraph/cli/invoke_structured_query.py index 9f5f8540..af2060bb 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_structured_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_structured_query.py @@ -13,7 +13,9 @@ from tabulate import tabulate default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -def format_output(data, output_format): +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") +def format_output(data, output_format, token=None, workspace="default"): """Format structured query response data in the specified format""" if not data: return "No data returned" @@ -79,11 +81,11 @@ def format_table_data(rows, table_name, output_format): else: return json.dumps({table_name: rows}, indent=2) -def structured_query(url, flow_id, question, user='trustgraph', collection='default', output_format='table'): +def structured_query(url, flow_id, question, collection='default', output_format='table', token=None, workspace="default"): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) - resp = api.structured_query(question=question, user=user, collection=collection) + resp = api.structured_query(question=question, collection=collection) # Check for errors if "error" in resp and resp["error"]: @@ -119,6 +121,17 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) parser.add_argument( '-f', '--flow-id', @@ -132,12 +145,6 @@ def main(): help='Natural language question to execute', ) - parser.add_argument( - '--user', - default='trustgraph', - help='Cassandra keyspace identifier (default: trustgraph)' - ) - parser.add_argument( '--collection', default='default', @@ -159,9 +166,12 @@ def main(): url=args.url, flow_id=args.flow_id, question=args.question, - user=args.user, collection=args.collection, output_format=args.format, + token=args.token, + + workspace = args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/list_collections.py b/trustgraph-cli/trustgraph/cli/list_collections.py index 4086f471..e2f90f56 100644 --- a/trustgraph-cli/trustgraph/cli/list_collections.py +++ b/trustgraph-cli/trustgraph/cli/list_collections.py @@ -1,23 +1,22 @@ """ -List collections for a user +List collections in a workspace """ import argparse import os import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def list_collections(url, user, tag_filter): +def list_collections(url, tag_filter, token=None, workspace="default"): - api = Api(url).collection() + api = Api(url, token=token, workspace=workspace).collection() - collections = api.list_collections(user=user, tag_filter=tag_filter) + collections = api.list_collections(tag_filter=tag_filter) - # Handle None or empty collections if not collections or len(collections) == 0: print("No collections found.") return @@ -54,26 +53,33 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-t', '--tag-filter', action='append', help='Filter by tags (can be specified multiple times)' ) + parser.add_argument( + '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: list_collections( url = args.api_url, - user = args.user, - tag_filter = args.tag_filter + tag_filter = args.tag_filter, + token = args.token, + workspace = args.workspace, ) except Exception as e: @@ -81,4 +87,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/list_config_items.py b/trustgraph-cli/trustgraph/cli/list_config_items.py index 5cd0f233..8bc3f683 100644 --- a/trustgraph-cli/trustgraph/cli/list_config_items.py +++ b/trustgraph-cli/trustgraph/cli/list_config_items.py @@ -9,10 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def list_config_items(url, config_type, format_type, token=None): +def list_config_items(url, config_type, format_type, token=None, + workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() keys = api.list(config_type) @@ -54,6 +56,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -63,6 +71,7 @@ def main(): config_type=args.type, format_type=args.format, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/list_explain_traces.py b/trustgraph-cli/trustgraph/cli/list_explain_traces.py index e6d1e075..9bc87db6 100644 --- a/trustgraph-cli/trustgraph/cli/list_explain_traces.py +++ b/trustgraph-cli/trustgraph/cli/list_explain_traces.py @@ -18,7 +18,7 @@ from trustgraph.api import Api, ExplainabilityClient default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' # Retrieval graph @@ -86,9 +86,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -120,7 +120,7 @@ def main(): args = parser.parse_args() try: - api = Api(args.api_url, token=args.token) + api = Api(args.api_url, token=args.token, workspace=args.workspace) socket = api.socket() flow = socket.flow(args.flow_id) explain_client = ExplainabilityClient(flow) @@ -129,7 +129,6 @@ def main(): # List all sessions — uses persistent websocket via SocketClient questions = explain_client.list_sessions( graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection, limit=args.limit, ) @@ -141,7 +140,6 @@ def main(): session_type = explain_client.detect_session_type( q.uri, graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection ) diff --git a/trustgraph-cli/trustgraph/cli/load_doc_embeds.py b/trustgraph-cli/trustgraph/cli/load_doc_embeds.py index 20c78515..a776c59b 100644 --- a/trustgraph-cli/trustgraph/cli/load_doc_embeds.py +++ b/trustgraph-cli/trustgraph/cli/load_doc_embeds.py @@ -46,7 +46,6 @@ async def load_de(running, queue, url): "metadata": { "id": msg["m"]["i"], "metadata": msg["m"]["m"], - "user": msg["m"]["u"], "collection": msg["m"]["c"], }, "chunks": [ @@ -77,7 +76,7 @@ async def stats(running): f"Graph embeddings: {de_counts:10d}" ) -async def loader(running, de_queue, path, format, user, collection): +async def loader(running, de_queue, path, format, collection): if format == "json": @@ -96,9 +95,6 @@ async def loader(running, de_queue, path, format, user, collection): except: break - if user: - unpacked["metadata"]["user"] = user - if collection: unpacked["metadata"]["collection"] = collection @@ -148,9 +144,9 @@ async def run(running, **args): running=running, de_queue=de_q, path=args["input_file"], format=args["format"], - user=args["user"], collection=args["collection"], + collection=args["collection"], ) - + ) de_task = asyncio.create_task( @@ -178,7 +174,6 @@ async def main(running): ) default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/") - default_user = "trustgraph" collection = "default" parser.add_argument( @@ -207,11 +202,6 @@ async def main(running): help=f'Output format (default: msgpack)', ) - parser.add_argument( - '--user', - help=f'User ID to load as (default: from input)' - ) - parser.add_argument( '--collection', help=f'Collection ID to load as (default: from input)' diff --git a/trustgraph-cli/trustgraph/cli/load_kg_core.py b/trustgraph-cli/trustgraph/cli/load_kg_core.py index 008b124f..281255be 100644 --- a/trustgraph-cli/trustgraph/cli/load_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/load_kg_core.py @@ -6,20 +6,19 @@ run this utility. import argparse import os -import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_flow = "default" default_collection = "default" -def load_kg_core(url, user, id, flow, collection): +def load_kg_core(url, id, flow, collection, token=None, workspace="default"): - api = Api(url).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() - class_names = api.load_kg_core(user = user, id = id, flow=flow, - collection=collection) + api.load_kg_core(id=id, flow=flow, collection=collection) def main(): @@ -34,12 +33,6 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', - ) - parser.add_argument( '--id', '--identifier', required=True, @@ -49,13 +42,25 @@ def main(): parser.add_argument( '-f', '--flow-id', default=default_flow, - help=f'Flow ID (default: {default_flow}', + help=f'Flow ID (default: {default_flow})', ) parser.add_argument( '-C', '--collection', default=default_collection, - help=f'Collection ID (default: {default_collection}', + help=f'Collection ID (default: {default_collection})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) args = parser.parse_args() @@ -64,10 +69,11 @@ def main(): load_kg_core( url=args.api_url, - user=args.user, id=args.id, flow=args.flow_id, collection=args.collection, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/load_knowledge.py b/trustgraph-cli/trustgraph/cli/load_knowledge.py index 5e96850f..3bb4f106 100644 --- a/trustgraph-cli/trustgraph/cli/load_knowledge.py +++ b/trustgraph-cli/trustgraph/cli/load_knowledge.py @@ -13,7 +13,7 @@ from trustgraph.log_level import LogLevel default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' class KnowledgeLoader: @@ -22,19 +22,18 @@ class KnowledgeLoader: self, files, flow, - user, collection, document_id, url=default_url, - token=None, + token=None, workspace="default", ): self.files = files self.flow = flow - self.user = user self.collection = collection self.document_id = document_id self.url = url self.token = token + self.workspace = workspace def load_triples_from_file(self, file) -> Iterator[Triple]: """Generator that yields Triple objects from a Turtle file""" @@ -43,11 +42,9 @@ class KnowledgeLoader: g.parse(file, format="turtle") for e in g: - # Extract subject, predicate, object s_value = str(e[0]) p_value = str(e[1]) - # Check if object is a URI or literal if isinstance(e[2], rdflib.term.URIRef): o_value = str(e[2]) o_is_uri = True @@ -55,9 +52,6 @@ class KnowledgeLoader: o_value = str(e[2]) o_is_uri = False - # Create Triple object - # Note: The Triple dataclass has 's', 'p', 'o' fields as strings - # The API will handle the metadata wrapping yield Triple(s=s_value, p=p_value, o=o_value) def load_entity_contexts_from_file(self, file) -> Iterator[Tuple[str, str]]: @@ -67,11 +61,9 @@ class KnowledgeLoader: g.parse(file, format="turtle") for s, p, o in g: - # If object is a URI, skip (we only want literal contexts) if isinstance(o, rdflib.term.URIRef): continue - # If object is a literal, create entity context for subject s_str = str(s) o_str = str(o) @@ -81,11 +73,9 @@ class KnowledgeLoader: """Load triples and entity contexts using Python API""" try: - # Create API client - api = Api(url=self.url, token=self.token) + api = Api(url=self.url, token=self.token, workspace=self.workspace) bulk = api.bulk() - # Load triples from all files print("Loading triples...") total_triples = 0 for file in self.files: @@ -104,7 +94,7 @@ class KnowledgeLoader: metadata={ "id": self.document_id, "metadata": [], - "user": self.user, + "user": self.workspace, "collection": self.collection } ) @@ -113,20 +103,16 @@ class KnowledgeLoader: print(f"Triples loaded. Total: {total_triples}") - # Load entity contexts from all files print("Loading entity contexts...") total_contexts = 0 for file in self.files: print(f" Processing {file}...") count = 0 - # Convert tuples to the format expected by import_entity_contexts - # Entity must be in Term format: {"t": "i", "i": uri} for IRI def entity_context_generator(): nonlocal count for entity, context in self.load_entity_contexts_from_file(file): count += 1 - # Entities from RDF are URIs, use IRI term format yield { "entity": {"t": "i", "i": entity}, "context": context @@ -138,7 +124,7 @@ class KnowledgeLoader: metadata={ "id": self.document_id, "metadata": [], - "user": self.user, + "user": self.workspace, "collection": self.collection } ) @@ -170,6 +156,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--document-id', required=True, @@ -182,12 +174,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -210,8 +196,8 @@ def main(): token=args.token, flow=args.flow_id, files=args.files, - user=args.user, collection=args.collection, + workspace=args.workspace, ) loader.run() diff --git a/trustgraph-cli/trustgraph/cli/load_sample_documents.py b/trustgraph-cli/trustgraph/cli/load_sample_documents.py index 186006a8..1907e8a2 100644 --- a/trustgraph-cli/trustgraph/cli/load_sample_documents.py +++ b/trustgraph-cli/trustgraph/cli/load_sample_documents.py @@ -14,6 +14,7 @@ from trustgraph.api.types import hash, Uri, Literal, Triple default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") from requests.adapters import HTTPAdapter @@ -656,11 +657,10 @@ documents = [ class Loader: def __init__( - self, url, user, token=None + self, url, token=None, workspace="default", ): - self.api = Api(url, token=token).library() - self.user = user + self.api = Api(url, token=token, workspace=workspace).library() def load(self, documents): @@ -689,10 +689,10 @@ class Loader: print(" adding...") self.api.add_document( - id = doc["id"], metadata = doc["metadata"], - user = self.user, kind = doc["kind"], title = doc["title"], - comments = doc["comments"], tags = doc["tags"], - document = content + id=doc["id"], metadata=doc["metadata"], + kind=doc["kind"], title=doc["title"], + comments=doc["comments"], tags=doc["tags"], + document=content, ) print(" successful.") @@ -714,26 +714,26 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-t', '--token', default=default_token, help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: p = Loader( url=args.url, - user=args.user, token=args.token, + workspace=args.workspace, ) p.load(documents) diff --git a/trustgraph-cli/trustgraph/cli/load_structured_data.py b/trustgraph-cli/trustgraph/cli/load_structured_data.py index fa167917..b85392c9 100644 --- a/trustgraph-cli/trustgraph/cli/load_structured_data.py +++ b/trustgraph-cli/trustgraph/cli/load_structured_data.py @@ -23,6 +23,7 @@ logger = logging.getLogger(__name__) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def load_structured_data( @@ -43,7 +44,8 @@ def load_structured_data( collection: str = 'default', dry_run: bool = False, verbose: bool = False, - token: str = None + token: str = None, + workspace: str = "default", ): """ Load structured data using a descriptor configuration. @@ -78,7 +80,7 @@ def load_structured_data( logger.info("Step 1: Analyzing data to discover best matching schema...") # Step 1: Auto-discover schema (reuse discover_schema logic) - discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger) + discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace) if not discovered_schema: logger.error("Failed to discover suitable schema automatically") print("❌ Could not automatically determine the best schema for your data.") @@ -90,7 +92,7 @@ def load_structured_data( # Step 2: Auto-generate descriptor logger.info("Step 2: Generating descriptor configuration...") - auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger) + auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace) if not auto_descriptor: logger.error("Failed to generate descriptor automatically") print("❌ Could not automatically generate descriptor configuration.") @@ -137,7 +139,7 @@ def load_structured_data( batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000) # Send to TrustGraph using shared function - imported_count = _send_to_trustgraph(output_objects, api_url, flow, batch_size, token=token) + imported_count = _send_to_trustgraph(output_objects, api_url, flow, batch_size, token=token, workspace=workspace) # Summary format_info = descriptor.get('format', {}) @@ -172,7 +174,7 @@ def load_structured_data( logger.info(f"Sample chars: {sample_chars} characters") # Use the helper function to discover schema (get raw response for display) - response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True) + response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace) if response: # Debug: print response type and content @@ -203,7 +205,7 @@ def load_structured_data( # If no schema specified, discover it first if not schema_name: logger.info("No schema specified, auto-discovering...") - schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger) + schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace) if not schema_name: print("Error: Could not determine schema automatically.") print("Please specify a schema using --schema-name or run --discover-schema first.") @@ -213,7 +215,7 @@ def load_structured_data( logger.info(f"Target schema: {schema_name}") # Generate descriptor using helper function - descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger) + descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace) if descriptor: # Output the generated descriptor @@ -573,7 +575,7 @@ def _process_data_pipeline(input_file, descriptor_file, user, collection, sample return output_records, descriptor -def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None): +def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, workspace="default"): """Send ExtractedObject records to TrustGraph using Python API""" from trustgraph.api import Api @@ -582,7 +584,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None): logger.info(f"Importing {total_records} records to TrustGraph...") # Use Python API bulk import - api = Api(api_url, token=token) + api = Api(api_url, token=token, workspace=workspace) bulk = api.bulk() bulk.import_rows(flow=flow, rows=iter(rows)) @@ -604,7 +606,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None): # Helper functions for auto mode -def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False): +def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"): """Auto-discover the best matching schema for the input data Args: @@ -627,7 +629,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur # Import API modules from trustgraph.api import Api from trustgraph.api.types import ConfigKey - api = Api(api_url) + api = Api(api_url, workspace=workspace) config_api = api.config() # Get available schemas @@ -708,7 +710,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur return None -def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger): +def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"): """Auto-generate descriptor configuration for the discovered schema""" try: # Read sample data @@ -718,7 +720,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl # Import API modules from trustgraph.api import Api from trustgraph.api.types import ConfigKey - api = Api(api_url) + api = Api(api_url, workspace=workspace) config_api = api.config() # Get schema definition @@ -885,12 +887,6 @@ For more information on the descriptor format, see: help='TrustGraph flow name to use for prompts and import (default: default)' ) - parser.add_argument( - '--user', - default='trustgraph', - help='User name for metadata (default: trustgraph)' - ) - parser.add_argument( '--collection', default='default', @@ -997,6 +993,12 @@ For more information on the descriptor format, see: help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() # Input validation @@ -1050,7 +1052,8 @@ For more information on the descriptor format, see: collection=args.collection, dry_run=args.dry_run, verbose=args.verbose, - token=args.token + token=args.token, + workspace=args.workspace, ) except FileNotFoundError as e: print(f"Error: File not found - {e}", file=sys.stderr) diff --git a/trustgraph-cli/trustgraph/cli/load_turtle.py b/trustgraph-cli/trustgraph/cli/load_turtle.py index adb578f5..ad5ca041 100644 --- a/trustgraph-cli/trustgraph/cli/load_turtle.py +++ b/trustgraph-cli/trustgraph/cli/load_turtle.py @@ -13,7 +13,7 @@ from trustgraph.log_level import LogLevel default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' class Loader: @@ -22,15 +22,14 @@ class Loader: self, files, flow, - user, collection, document_id, url=default_url, - token=None, + token=None, workspace="default", ): self.files = files self.flow = flow - self.user = user + self.workspace = workspace self.collection = collection self.document_id = document_id self.url = url @@ -43,28 +42,23 @@ class Loader: g.parse(file, format="turtle") for e in g: - # Extract subject, predicate, object s_value = str(e[0]) p_value = str(e[1]) - # Check if object is a URI or literal if isinstance(e[2], rdflib.term.URIRef): o_value = str(e[2]) else: o_value = str(e[2]) - # Create Triple object yield Triple(s=s_value, p=p_value, o=o_value) def run(self): """Load triples using Python API""" try: - # Create API client - api = Api(url=self.url, token=self.token) + api = Api(url=self.url, token=self.token, workspace=self.workspace) bulk = api.bulk() - # Load triples from all files print("Loading triples...") for file in self.files: print(f" Processing {file}...") @@ -76,7 +70,7 @@ class Loader: metadata={ "id": self.document_id, "metadata": [], - "user": self.user, + "user": self.workspace, "collection": self.collection } ) @@ -106,6 +100,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--document-id', required=True, @@ -118,12 +118,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -146,8 +140,8 @@ def main(): token=args.token, flow=args.flow_id, files=args.files, - user=args.user, collection=args.collection, + workspace=args.workspace, ) loader.run() diff --git a/trustgraph-cli/trustgraph/cli/put_config_item.py b/trustgraph-cli/trustgraph/cli/put_config_item.py index d79864a4..fda9cbeb 100644 --- a/trustgraph-cli/trustgraph/cli/put_config_item.py +++ b/trustgraph-cli/trustgraph/cli/put_config_item.py @@ -10,10 +10,12 @@ from trustgraph.api.types import ConfigValue default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def put_config_item(url, config_type, key, value, token=None): +def put_config_item(url, config_type, key, value, token=None, + workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() config_value = ConfigValue(type=config_type, key=key, value=value) api.put([config_value]) @@ -63,6 +65,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -78,6 +86,7 @@ def main(): key=args.key, value=value, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py b/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py index 740a224a..96db6bec 100644 --- a/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py +++ b/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py @@ -10,10 +10,12 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def put_flow_blueprint(url, blueprint_name, config, token=None): +def put_flow_blueprint(url, blueprint_name, config, token=None, + workspace="default"): - api = Api(url, token=token) + api = Api(url, token=token, workspace=workspace) blueprint_names = api.flow().put_blueprint(blueprint_name, config) @@ -36,6 +38,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', help=f'Flow blueprint name', @@ -55,6 +63,7 @@ def main(): blueprint_name=args.blueprint_name, config=json.loads(args.config), token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/put_kg_core.py b/trustgraph-cli/trustgraph/cli/put_kg_core.py index cd0738fe..bd3169c8 100644 --- a/trustgraph-cli/trustgraph/cli/put_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/put_kg_core.py @@ -1,10 +1,9 @@ """ -Uses the agent service to answer a question +Puts a knowledge core into the knowledge manager via the API socket. """ import argparse import os -import textwrap import uuid import asyncio import json @@ -12,18 +11,17 @@ from websockets.asyncio.client import connect import msgpack default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') -default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def read_message(unpacked, id): -def read_message(unpacked, id, user): - if unpacked[0] == "ge": msg = unpacked[1] return "ge", { "metadata": { "id": id, "metadata": msg["m"]["m"], - "user": user, "collection": "default", # Not used? }, "entities": [ @@ -40,7 +38,6 @@ def read_message(unpacked, id, user): "metadata": { "id": id, "metadata": msg["m"]["m"], - "user": user, "collection": "default", # Not used by receiver? }, "triples": msg["t"], @@ -48,7 +45,7 @@ def read_message(unpacked, id, user): else: raise RuntimeError("Unpacked unexpected messsage type", unpacked[0]) -async def put(url, user, id, input, token=None): +async def put(url, workspace, id, input, token=None): if not url.endswith("/"): url += "/" @@ -60,7 +57,6 @@ async def put(url, user, id, input, token=None): async with connect(url) as ws: - ge = 0 t = 0 @@ -75,7 +71,7 @@ async def put(url, user, id, input, token=None): except: break - kind, msg = read_message(unpacked, id, user) + kind, msg = read_message(unpacked, id) mid = str(uuid.uuid4()) @@ -85,10 +81,11 @@ async def put(url, user, id, input, token=None): req = json.dumps({ "id": mid, + "workspace": workspace, "service": "knowledge", "request": { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "graph-embeddings": msg } @@ -100,10 +97,11 @@ async def put(url, user, id, input, token=None): req = json.dumps({ "id": mid, + "workspace": workspace, "service": "knowledge", "request": { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "triples": msg } @@ -117,7 +115,7 @@ async def put(url, user, id, input, token=None): # Retry loop, wait for right response to come back while True: - + msg = await ws.recv() msg = json.loads(msg) @@ -146,10 +144,11 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -176,11 +175,11 @@ def main(): asyncio.run( put( - url = args.url, - user = args.user, - id = args.id, - input = args.input, - token = args.token, + url=args.url, + workspace=args.workspace, + id=args.id, + input=args.input, + token=args.token, ) ) @@ -189,4 +188,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/query_graph.py b/trustgraph-cli/trustgraph/cli/query_graph.py index a2c38353..091f0599 100644 --- a/trustgraph-cli/trustgraph/cli/query_graph.py +++ b/trustgraph-cli/trustgraph/cli/query_graph.py @@ -23,9 +23,9 @@ import sys from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_collection = 'default' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def parse_inline_quoted_triple(value): @@ -285,15 +285,16 @@ def output_jsonl(triples): def query_graph( - url, flow_id, user, collection, limit, batch_size, + url, flow_id, collection, limit, batch_size, subject=None, predicate=None, obj=None, graph=None, - output_format="space", headers=False, token=None + output_format="space", headers=False, token=None, + workspace="default", ): """Query the triple store with pattern matching. Uses the API's triples_query_stream for efficient streaming delivery. """ - socket = Api(url, token=token).socket() + socket = Api(url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) all_triples = [] @@ -305,7 +306,6 @@ def query_graph( p=predicate, o=obj, g=graph, - user=user, collection=collection, limit=limit, batch_size=batch_size, @@ -456,13 +456,6 @@ def main(): help='Flow ID (default: default)' ) - std_group.add_argument( - '-U', '--user', - default=default_user, - metavar='USER', - help=f'User/keyspace (default: {default_user})' - ) - std_group.add_argument( '-C', '--collection', default=default_collection, @@ -477,6 +470,12 @@ def main(): help='Auth token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + std_group.add_argument( '-l', '--limit', type=int, @@ -550,7 +549,6 @@ def main(): query_graph( url=args.api_url, flow_id=args.flow_id, - user=args.user, collection=args.collection, limit=args.limit, batch_size=args.batch_size, @@ -561,6 +559,8 @@ def main(): output_format=args.format, headers=args.headers, token=args.token, + + workspace=args.workspace, ) except json.JSONDecodeError as e: diff --git a/trustgraph-cli/trustgraph/cli/remove_library_document.py b/trustgraph-cli/trustgraph/cli/remove_library_document.py index 07a1fd59..d6500d50 100644 --- a/trustgraph-cli/trustgraph/cli/remove_library_document.py +++ b/trustgraph-cli/trustgraph/cli/remove_library_document.py @@ -4,20 +4,19 @@ Remove a document from the library import argparse import os -import uuid from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def remove_doc(url, user, id, token=None): +def remove_doc(url, id, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - api.remove_document(user=user, id=id) + api.remove_document(id=id) def main(): @@ -32,12 +31,6 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '--identifier', '--id', required=True, @@ -50,15 +43,24 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: - remove_doc(args.url, args.user, args.identifier, token=args.token) + remove_doc( + args.url, args.identifier, + token=args.token, workspace=args.workspace, + ) except Exception as e: print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/save_doc_embeds.py b/trustgraph-cli/trustgraph/cli/save_doc_embeds.py index ca8d25de..99d6b4db 100644 --- a/trustgraph-cli/trustgraph/cli/save_doc_embeds.py +++ b/trustgraph-cli/trustgraph/cli/save_doc_embeds.py @@ -21,7 +21,7 @@ class Running: def get(self): return self.running def stop(self): self.running = False -async def fetch_de(running, queue, user, collection, url): +async def fetch_de(running, queue, collection, url): async with aiohttp.ClientSession() as session: @@ -38,10 +38,6 @@ async def fetch_de(running, queue, user, collection, url): data = msg.json() - if user: - if data["metadata"]["user"] != user: - continue - if collection: if data["metadata"]["collection"] != collection: continue @@ -52,7 +48,6 @@ async def fetch_de(running, queue, user, collection, url): "m": { "i": data["metadata"]["id"], "m": data["metadata"]["metadata"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "c": [ @@ -119,7 +114,7 @@ async def run(running, **args): de_task = asyncio.create_task( fetch_de( running=running, - queue=q, user=args["user"], collection=args["collection"], + queue=q, collection=args["collection"], url = f"{url}api/v1/flow/{flow_id}/export/document-embeddings" ) ) @@ -148,7 +143,6 @@ async def main(running): ) default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/") - default_user = "trustgraph" collection = "default" parser.add_argument( @@ -177,11 +171,6 @@ async def main(running): help=f'Output format (default: msgpack)', ) - parser.add_argument( - '--user', - help=f'User ID to filter on (default: no filter)' - ) - parser.add_argument( '--collection', help=f'Collection ID to filter on (default: no filter)' diff --git a/trustgraph-cli/trustgraph/cli/set_collection.py b/trustgraph-cli/trustgraph/cli/set_collection.py index dd4148ea..53aaa74d 100644 --- a/trustgraph-cli/trustgraph/cli/set_collection.py +++ b/trustgraph-cli/trustgraph/cli/set_collection.py @@ -8,15 +8,14 @@ import tabulate from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def set_collection(url, user, collection, name, description, tags, token=None): +def set_collection(url, collection, name, description, tags, token=None, workspace="default"): - api = Api(url, token=token).collection() + api = Api(url, token=token, workspace=workspace).collection() result = api.update_collection( - user=user, collection=collection, name=name, description=description, @@ -59,12 +58,6 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-n', '--name', help='Collection name' @@ -88,18 +81,24 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: set_collection( url = args.api_url, - user = args.user, collection = args.collection, name = args.name, description = args.description, tags = args.tags, - token = args.token + token = args.token, + workspace=args.workspace, ) except Exception as e: @@ -107,4 +106,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/set_mcp_tool.py b/trustgraph-cli/trustgraph/cli/set_mcp_tool.py index 7976adbc..002d87e8 100644 --- a/trustgraph-cli/trustgraph/cli/set_mcp_tool.py +++ b/trustgraph-cli/trustgraph/cli/set_mcp_tool.py @@ -21,6 +21,7 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def set_mcp_tool( url : str, @@ -31,7 +32,7 @@ def set_mcp_tool( token : str = None, ): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() # Build the MCP tool configuration config = { @@ -80,6 +81,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--id', required=True, @@ -126,6 +133,8 @@ def main(): tool_url=args.tool_url, auth_token=args.auth_token, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/set_prompt.py b/trustgraph-cli/trustgraph/cli/set_prompt.py index bffc2cf2..dbf9c326 100644 --- a/trustgraph-cli/trustgraph/cli/set_prompt.py +++ b/trustgraph-cli/trustgraph/cli/set_prompt.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def set_system(url, system, token=None): +def set_system(url, system, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() api.put([ ConfigValue(type="prompt", key="system", value=json.dumps(system)) @@ -22,9 +23,9 @@ def set_system(url, system, token=None): print("System prompt set.") -def set_prompt(url, id, prompt, response, schema, token=None): +def set_prompt(url, id, prompt, response, schema, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get([ ConfigKey(type="prompt", key="template-index") @@ -78,6 +79,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '--id', help=f'Prompt ID', diff --git a/trustgraph-cli/trustgraph/cli/set_token_costs.py b/trustgraph-cli/trustgraph/cli/set_token_costs.py index 19b8c703..83a356f9 100644 --- a/trustgraph-cli/trustgraph/cli/set_token_costs.py +++ b/trustgraph-cli/trustgraph/cli/set_token_costs.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def set_costs(api_url, model, input_costs, output_costs, token=None): +def set_costs(api_url, model, input_costs, output_costs, token=None, workspace="default"): - api = Api(api_url, token=token).config() + api = Api(api_url, token=token, workspace=workspace).config() api.put([ ConfigValue( @@ -28,7 +29,7 @@ def set_costs(api_url, model, input_costs, output_costs, token=None): def set_prompt(url, id, prompt, response, schema): - api = Api(url) + api = Api(url, workspace=workspace) values = api.config_get([ ConfigKey(type="prompt", key="template-index") @@ -102,6 +103,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: diff --git a/trustgraph-cli/trustgraph/cli/set_tool.py b/trustgraph-cli/trustgraph/cli/set_tool.py index c6412e48..45295089 100644 --- a/trustgraph-cli/trustgraph/cli/set_tool.py +++ b/trustgraph-cli/trustgraph/cli/set_tool.py @@ -28,6 +28,7 @@ import dataclasses default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") @dataclasses.dataclass class Argument: @@ -73,9 +74,10 @@ def set_tool( state : str, applicable_states : List[str], token : str = None, + workspace : str = "default", ): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get([ ConfigKey(type="agent", key="tool-index") @@ -181,6 +183,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '--id', help=f'Unique tool identifier', @@ -303,6 +311,8 @@ def main(): state=args.state, applicable_states=args.applicable_states, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_config.py b/trustgraph-cli/trustgraph/cli/show_config.py index 6f426533..130c59b7 100644 --- a/trustgraph-cli/trustgraph/cli/show_config.py +++ b/trustgraph-cli/trustgraph/cli/show_config.py @@ -9,10 +9,11 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() config, version = api.all() @@ -38,6 +39,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -45,6 +52,7 @@ def main(): show_config( url=args.api_url, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_explain_trace.py b/trustgraph-cli/trustgraph/cli/show_explain_trace.py index 90c0e452..93e0c783 100644 --- a/trustgraph-cli/trustgraph/cli/show_explain_trace.py +++ b/trustgraph-cli/trustgraph/cli/show_explain_trace.py @@ -36,7 +36,7 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' # Graphs @@ -559,9 +559,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -599,7 +599,7 @@ def main(): args = parser.parse_args() try: - api = Api(args.api_url, token=args.token) + api = Api(args.api_url, token=args.token, workspace=args.workspace) socket = api.socket() flow = socket.flow(args.flow_id) explain_client = ExplainabilityClient(flow) @@ -609,7 +609,7 @@ def main(): trace_type = explain_client.detect_session_type( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, + user="", collection=args.collection, ) @@ -618,7 +618,7 @@ def main(): trace = explain_client.fetch_agent_trace( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, + user="", collection=args.collection, api=api, max_content=args.max_answer, @@ -627,14 +627,14 @@ def main(): if args.format == 'json': print(json.dumps(trace_to_dict(trace, "agent"), indent=2)) else: - print_agent_text(trace, explain_client, api, args.user) + print_agent_text(trace, explain_client, api, "") elif trace_type == "docrag": # Fetch and display DocRAG trace trace = explain_client.fetch_docrag_trace( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, + user="", collection=args.collection, api=api, max_content=args.max_answer, @@ -643,14 +643,14 @@ def main(): if args.format == 'json': print(json.dumps(trace_to_dict(trace, "docrag"), indent=2)) else: - print_docrag_text(trace, explain_client, api, args.user) + print_docrag_text(trace, explain_client, api, "") else: # Fetch and display GraphRAG trace trace = explain_client.fetch_graphrag_trace( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, + user="", collection=args.collection, api=api, max_content=args.max_answer, @@ -661,7 +661,7 @@ def main(): else: print_graphrag_text( trace, explain_client, flow, - args.user, args.collection, + "", args.collection, api=api, show_provenance=args.show_provenance ) diff --git a/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py b/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py index 4f87712c..286d6ca1 100644 --- a/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py +++ b/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py @@ -17,7 +17,7 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' # Predicates @@ -159,7 +159,7 @@ def get_document_content(api, user, doc_id, max_content): """Fetch document content from librarian API.""" try: library = api.library() - content = library.get_document_content(user=user, id=doc_id) + content = library.get_document_content(id=doc_id) # Try to decode as text try: @@ -331,9 +331,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -371,14 +371,14 @@ def main(): args = parser.parse_args() try: - api = Api(args.api_url, token=args.token) + api = Api(args.api_url, token=args.token, workspace=args.workspace) socket = api.socket() try: hierarchy = build_hierarchy( socket=socket, flow_id=args.flow_id, - user=args.user, + user="", collection=args.collection, root_uri=args.document_id, api=api if args.show_content else None, diff --git a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py index 8d16d098..4924c925 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py @@ -11,6 +11,7 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def format_parameters(params_metadata, param_type_defs): """ @@ -44,12 +45,13 @@ def format_parameters(params_metadata, param_type_defs): return "\n".join(param_list) -async def fetch_data(client): +async def fetch_data(client, workspace): """Fetch all data needed for show_flow_blueprints concurrently.""" # Round 1: list blueprints resp = await client._send_request("flow", None, { "operation": "list-blueprints", + "workspace": workspace, }) blueprint_names = resp.get("blueprint-names", []) @@ -60,6 +62,7 @@ async def fetch_data(client): blueprint_tasks = [ client._send_request("flow", None, { "operation": "get-blueprint", + "workspace": workspace, "blueprint-name": name, }) for name in blueprint_names @@ -84,6 +87,7 @@ async def fetch_data(client): param_type_tasks = [ client._send_request("config", None, { "operation": "get", + "workspace": workspace, "keys": [{"type": "parameter-type", "key": pt}], }) for pt in param_types_needed @@ -100,14 +104,16 @@ async def fetch_data(client): return blueprint_names, blueprints, param_type_defs -async def _show_flow_blueprints_async(url, token=None): +async def _show_flow_blueprints_async(url, token=None, workspace="default"): async with AsyncSocketClient(url, timeout=60, token=token) as client: - return await fetch_data(client) + return await fetch_data(client, workspace) -def show_flow_blueprints(url, token=None): +def show_flow_blueprints(url, token=None, workspace="default"): blueprint_names, blueprints, param_type_defs = asyncio.run( - _show_flow_blueprints_async(url, token=token) + _show_flow_blueprints_async( + url, token=token, workspace=workspace, + ) ) if not blueprint_names: @@ -156,6 +162,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -163,6 +175,7 @@ def main(): show_flow_blueprints( url=args.api_url, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_flow_state.py b/trustgraph-cli/trustgraph/cli/show_flow_state.py index d5d87f2c..8fec04ec 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_state.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_state.py @@ -10,10 +10,12 @@ import os default_metrics_url = "http://localhost:8088/api/metrics" default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def dump_status(metrics_url, api_url, flow_id, token=None): +def dump_status(metrics_url, api_url, flow_id, token=None, + workspace="default"): - api = Api(api_url, token=token).flow() + api = Api(api_url, token=token, workspace=workspace).flow() flow = api.get(flow_id) blueprint_name = flow["blueprint-name"] @@ -84,11 +86,20 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: - dump_status(args.metrics_url, args.api_url, args.flow_id, token=args.token) + dump_status( + args.metrics_url, args.api_url, args.flow_id, + token=args.token, workspace=args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_flows.py b/trustgraph-cli/trustgraph/cli/show_flows.py index f7a14469..6e9479f9 100644 --- a/trustgraph-cli/trustgraph/cli/show_flows.py +++ b/trustgraph-cli/trustgraph/cli/show_flows.py @@ -11,6 +11,7 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def describe_interfaces(intdefs, flow): @@ -97,17 +98,19 @@ def format_parameters(flow_params, blueprint_params_metadata, param_type_defs): return "\n".join(param_list) if param_list else "None" -async def fetch_show_flows(client): +async def fetch_show_flows(client, workspace): """Fetch all data needed for show_flows concurrently.""" # Round 1: list interfaces and list flows in parallel interface_names_resp, flow_ids_resp = await asyncio.gather( client._send_request("config", None, { "operation": "list", + "workspace": workspace, "type": "interface-description", }), client._send_request("flow", None, { "operation": "list-flows", + "workspace": workspace, }), ) @@ -115,12 +118,13 @@ async def fetch_show_flows(client): flow_ids = flow_ids_resp.get("flow-ids", []) if not flow_ids: - return {}, [], {}, {} + return {}, [], {}, {}, {} # Round 2: get all interfaces + all flows in parallel interface_tasks = [ client._send_request("config", None, { "operation": "get", + "workspace": workspace, "keys": [{"type": "interface-description", "key": name}], }) for name in interface_names @@ -129,6 +133,7 @@ async def fetch_show_flows(client): flow_tasks = [ client._send_request("flow", None, { "operation": "get-flow", + "workspace": workspace, "flow-id": fid, }) for fid in flow_ids @@ -163,6 +168,7 @@ async def fetch_show_flows(client): blueprint_tasks = [ client._send_request("flow", None, { "operation": "get-blueprint", + "workspace": workspace, "blueprint-name": bp_name, }) for bp_name in blueprint_names @@ -186,6 +192,7 @@ async def fetch_show_flows(client): param_type_tasks = [ client._send_request("config", None, { "operation": "get", + "workspace": workspace, "keys": [{"type": "parameter-type", "key": pt}], }) for pt in param_types_needed @@ -204,14 +211,16 @@ async def fetch_show_flows(client): return interface_defs, flow_ids, flows, blueprints, param_type_defs -async def _show_flows_async(url, token=None): +async def _show_flows_async(url, token=None, workspace="default"): async with AsyncSocketClient(url, timeout=60, token=token) as client: - return await fetch_show_flows(client) + return await fetch_show_flows(client, workspace) -def show_flows(url, token=None): +def show_flows(url, token=None, workspace="default"): - result = asyncio.run(_show_flows_async(url, token=token)) + result = asyncio.run(_show_flows_async( + url, token=token, workspace=workspace, + )) interface_defs, flow_ids, flows, blueprints, param_type_defs = result @@ -269,6 +278,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -276,6 +291,7 @@ def main(): show_flows( url=args.api_url, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_graph.py b/trustgraph-cli/trustgraph/cli/show_graph.py index 8db4edf4..6063b05a 100644 --- a/trustgraph-cli/trustgraph/cli/show_graph.py +++ b/trustgraph-cli/trustgraph/cli/show_graph.py @@ -13,9 +13,9 @@ import os from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_collection = 'default' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") # Named graph constants for convenience GRAPH_DEFAULT = "" @@ -23,14 +23,13 @@ GRAPH_SOURCE = "urn:graph:source" GRAPH_RETRIEVAL = "urn:graph:retrieval" -def show_graph(url, flow_id, user, collection, limit, batch_size, graph=None, show_graph_column=False, token=None): +def show_graph(url, flow_id, collection, limit, batch_size, graph=None, show_graph_column=False, token=None, workspace="default"): - socket = Api(url, token=token).socket() + socket = Api(url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) try: for batch in flow.triples_query_stream( - user=user, collection=collection, s=None, p=None, o=None, g=graph, # Filter by named graph (None = all graphs) @@ -73,12 +72,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -91,6 +84,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-l', '--limit', type=int, @@ -129,13 +128,13 @@ def main(): show_graph( url = args.api_url, flow_id = args.flow_id, - user = args.user, collection = args.collection, limit = args.limit, batch_size = args.batch_size, graph = graph, show_graph_column = args.show_graph, token = args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_kg_cores.py b/trustgraph-cli/trustgraph/cli/show_kg_cores.py index ea295543..c9d47889 100644 --- a/trustgraph-cli/trustgraph/cli/show_kg_cores.py +++ b/trustgraph-cli/trustgraph/cli/show_kg_cores.py @@ -4,16 +4,15 @@ Shows knowledge cores import argparse import os -import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_cores(url, user, token=None): +def show_cores(url, token=None, workspace="default"): - api = Api(url, token=token).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() ids = api.list_kg_cores() @@ -26,7 +25,7 @@ def show_cores(url, user, token=None): def main(): parser = argparse.ArgumentParser( - prog='tg-show-flows', + prog='tg-show-kg-cores', description=__doc__, ) @@ -43,9 +42,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) args = parser.parse_args() @@ -54,8 +53,8 @@ def main(): show_cores( url=args.api_url, - user=args.user, token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -63,4 +62,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/show_library_documents.py b/trustgraph-cli/trustgraph/cli/show_library_documents.py index 6eeceb70..12a89f1a 100644 --- a/trustgraph-cli/trustgraph/cli/show_library_documents.py +++ b/trustgraph-cli/trustgraph/cli/show_library_documents.py @@ -10,13 +10,13 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_docs(url, user, token=None): +def show_docs(url, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - docs = api.get_documents(user=user) + docs = api.get_documents() if len(docs) == 0: print("No documents.") @@ -60,9 +60,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) args = parser.parse_args() @@ -71,8 +71,8 @@ def main(): show_docs( url = args.api_url, - user = args.user, token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_library_processing.py b/trustgraph-cli/trustgraph/cli/show_library_processing.py index 9ab69355..700a0f83 100644 --- a/trustgraph-cli/trustgraph/cli/show_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/show_library_processing.py @@ -4,18 +4,17 @@ import argparse import os import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_procs(url, user, token=None): +def show_procs(url, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - procs = api.get_processings(user = user) + procs = api.get_processings() if len(procs) == 0: print("No processing objects.") @@ -52,24 +51,26 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-t', '--token', default=default_token, help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: show_procs( - url = args.api_url, user = args.user, token = args.token + url=args.api_url, + token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -77,4 +78,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/show_mcp_tools.py b/trustgraph-cli/trustgraph/cli/show_mcp_tools.py index 24cbfcfe..d5f7a1c1 100644 --- a/trustgraph-cli/trustgraph/cli/show_mcp_tools.py +++ b/trustgraph-cli/trustgraph/cli/show_mcp_tools.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get_values(type="mcp") @@ -64,6 +65,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -71,6 +78,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_prompts.py b/trustgraph-cli/trustgraph/cli/show_prompts.py index 0e1cb2ae..cad6f317 100644 --- a/trustgraph-cli/trustgraph/cli/show_prompts.py +++ b/trustgraph-cli/trustgraph/cli/show_prompts.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get([ ConfigKey(type="prompt", key="system"), @@ -85,6 +86,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -92,6 +99,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_token_costs.py b/trustgraph-cli/trustgraph/cli/show_token_costs.py index adc13ad7..c7a7bff2 100644 --- a/trustgraph-cli/trustgraph/cli/show_token_costs.py +++ b/trustgraph-cli/trustgraph/cli/show_token_costs.py @@ -13,10 +13,11 @@ tabulate.PRESERVE_WHITESPACE = True default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() models = api.list("token-cost") @@ -68,6 +69,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -75,6 +82,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_tools.py b/trustgraph-cli/trustgraph/cli/show_tools.py index d77f1fae..51aeacbf 100644 --- a/trustgraph-cli/trustgraph/cli/show_tools.py +++ b/trustgraph-cli/trustgraph/cli/show_tools.py @@ -19,10 +19,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get_values(type="tool") @@ -116,6 +117,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -123,6 +130,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/start_flow.py b/trustgraph-cli/trustgraph/cli/start_flow.py index e04e241d..f65ffc49 100644 --- a/trustgraph-cli/trustgraph/cli/start_flow.py +++ b/trustgraph-cli/trustgraph/cli/start_flow.py @@ -18,10 +18,12 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def start_flow(url, blueprint_name, flow_id, description, parameters=None, token=None): +def start_flow(url, blueprint_name, flow_id, description, parameters=None, + token=None, workspace="default"): - api = Api(url, token=token).flow() + api = Api(url, token=token, workspace=workspace).flow() api.start( blueprint_name = blueprint_name, @@ -49,6 +51,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', required=True, @@ -120,6 +128,7 @@ def main(): description = args.description, parameters = parameters, token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/start_library_processing.py b/trustgraph-cli/trustgraph/cli/start_library_processing.py index ff87ea9f..27b5f33d 100644 --- a/trustgraph-cli/trustgraph/cli/start_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/start_library_processing.py @@ -4,19 +4,18 @@ Submits a library document for processing import argparse import os -import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def start_processing( - url, user, document_id, id, flow, collection, tags, token=None + url, document_id, id, flow, collection, tags, + token=None, workspace="default", ): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() if tags: tags = tags.split(",") @@ -27,9 +26,8 @@ def start_processing( id = id, document_id = document_id, flow = flow, - user = user, collection = collection, - tags = tags + tags = tags, ) def main(): @@ -52,9 +50,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -91,14 +89,14 @@ def main(): try: start_processing( - url = args.api_url, - user = args.user, - document_id = args.document_id, - id = args.id, - flow = args.flow_id, - collection = args.collection, - tags = args.tags, - token = args.token, + url=args.api_url, + document_id=args.document_id, + id=args.id, + flow=args.flow_id, + collection=args.collection, + tags=args.tags, + token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -106,4 +104,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/stop_flow.py b/trustgraph-cli/trustgraph/cli/stop_flow.py index ae3a0415..7e2d0798 100644 --- a/trustgraph-cli/trustgraph/cli/stop_flow.py +++ b/trustgraph-cli/trustgraph/cli/stop_flow.py @@ -10,10 +10,11 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def stop_flow(url, flow_id, token=None): +def stop_flow(url, flow_id, token=None, workspace="default"): - api = Api(url, token=token).flow() + api = Api(url, token=token, workspace=workspace).flow() api.stop(id = flow_id) @@ -36,6 +37,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--flow-id', required=True, @@ -50,6 +57,7 @@ def main(): url=args.api_url, flow_id=args.flow_id, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/stop_library_processing.py b/trustgraph-cli/trustgraph/cli/stop_library_processing.py index 3d8a2c56..72a8dbb8 100644 --- a/trustgraph-cli/trustgraph/cli/stop_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/stop_library_processing.py @@ -5,21 +5,17 @@ procesing, it doesn't stop in-flight processing at the moment. import argparse import os -import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def stop_processing( - url, user, id, token=None -): +def stop_processing(url, id, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - api.stop_processing(user = user, id = id) + api.stop_processing(id=id) def main(): @@ -41,9 +37,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -57,10 +53,10 @@ def main(): try: stop_processing( - url = args.api_url, - user = args.user, - id = args.id, - token = args.token, + url=args.api_url, + id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -68,4 +64,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/unload_kg_core.py b/trustgraph-cli/trustgraph/cli/unload_kg_core.py index 47f811f3..45c56067 100644 --- a/trustgraph-cli/trustgraph/cli/unload_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/unload_kg_core.py @@ -1,25 +1,21 @@ """ -Starts a load operation on a knowledge core which is already stored by -the knowledge manager. You could load a core with tg-put-kg-core and then -run this utility. +Unloads a knowledge core from a flow. """ import argparse import os -import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_flow = "default" -default_collection = "default" -def unload_kg_core(url, user, id, flow, token=None): +def unload_kg_core(url, id, flow, token=None, workspace="default"): - api = Api(url, token=token).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() - class_names = api.unload_kg_core(user = user, id = id, flow=flow) + api.unload_kg_core(id=id, flow=flow) def main(): @@ -41,9 +37,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -55,7 +51,7 @@ def main(): parser.add_argument( '-f', '--flow-id', default=default_flow, - help=f'Flow ID (default: {default_flow}', + help=f'Flow ID (default: {default_flow})', ) args = parser.parse_args() @@ -64,10 +60,10 @@ def main(): unload_kg_core( url=args.api_url, - user=args.user, id=args.id, flow=args.flow_id, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/verify_system_status.py b/trustgraph-cli/trustgraph/cli/verify_system_status.py index 9491deaa..0b13a5a1 100644 --- a/trustgraph-cli/trustgraph/cli/verify_system_status.py +++ b/trustgraph-cli/trustgraph/cli/verify_system_status.py @@ -25,6 +25,7 @@ default_pulsar_url = "http://localhost:8080" default_api_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") default_ui_url = "http://localhost:8888" default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") class HealthChecker: @@ -210,10 +211,10 @@ def check_processors(url: str, min_processors: int, timeout: int, tr, token: Opt return False, tr.t("cli.verify_system_status.processors.error", error=str(e)) -def check_flow_blueprints(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_flow_blueprints(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if flow blueprints are loaded.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) flow_api = api.flow() blueprints = flow_api.list_blueprints() @@ -227,10 +228,10 @@ def check_flow_blueprints(url: str, timeout: int, tr, token: Optional[str] = Non return False, tr.t("cli.verify_system_status.flow_blueprints.error", error=str(e)) -def check_flows(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_flows(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if flow manager is responding.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) flow_api = api.flow() flows = flow_api.list() @@ -242,10 +243,10 @@ def check_flows(url: str, timeout: int, tr, token: Optional[str] = None) -> Tupl return False, tr.t("cli.verify_system_status.flows.error", error=str(e)) -def check_prompts(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_prompts(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if prompts are loaded.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) config = api.config() # Import ConfigKey here to avoid top-level import issues @@ -268,10 +269,10 @@ def check_prompts(url: str, timeout: int, tr, token: Optional[str] = None) -> Tu return False, tr.t("cli.verify_system_status.prompts.error", error=str(e)) -def check_library(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_library(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if library service is responding.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) library_api = api.library() # Try to get documents (with default user) @@ -376,6 +377,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)' ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-v', '--verbose', action='store_true', @@ -438,6 +445,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) checker.run_check( @@ -447,6 +455,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) checker.run_check( @@ -456,6 +465,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) print() @@ -471,6 +481,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) print() diff --git a/trustgraph-embeddings-hf/pyproject.toml b/trustgraph-embeddings-hf/pyproject.toml index 459f6123..70489969 100644 --- a/trustgraph-embeddings-hf/pyproject.toml +++ b/trustgraph-embeddings-hf/pyproject.toml @@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph." readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", - "trustgraph-flow>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", + "trustgraph-flow>=2.4,<2.5", "torch", "urllib3", "transformers", diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 492af385..8ba85adf 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "aiohttp", "anthropic", "scylla-driver", diff --git a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py index c793f9ca..8ea72260 100755 --- a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py +++ b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py @@ -26,42 +26,50 @@ class Service(ToolService): self.register_config_handler(self.on_mcp_config, types=["mcp"]) + # Per-workspace MCP service registries self.mcp_services = {} - async def on_mcp_config(self, config, version): + async def on_mcp_config(self, workspace, config, version): - logger.info(f"Got config version {version}") + logger.info( + f"Got config version {version} for workspace {workspace}" + ) if "mcp" not in config: - self.mcp_services = {} + self.mcp_services[workspace] = {} return - self.mcp_services = { + self.mcp_services[workspace] = { k: json.loads(v) for k, v in config["mcp"].items() } - async def invoke_tool(self, name, parameters): + async def invoke_tool(self, workspace, name, parameters): try: - if name not in self.mcp_services: - raise RuntimeError(f"MCP service {name} not known") + ws_services = self.mcp_services.get(workspace, {}) - if "url" not in self.mcp_services[name]: + if name not in ws_services: + raise RuntimeError( + f"MCP service {name} not known in workspace " + f"{workspace}" + ) + + if "url" not in ws_services[name]: raise RuntimeError(f"MCP service {name} URL not defined") - url = self.mcp_services[name]["url"] + url = ws_services[name]["url"] - if "remote-name" in self.mcp_services[name]: - remote_name = self.mcp_services[name]["remote-name"] + if "remote-name" in ws_services[name]: + remote_name = ws_services[name]["remote-name"] else: remote_name = name # Build headers with optional bearer token headers = {} - if "auth-token" in self.mcp_services[name]: - token = self.mcp_services[name]["auth-token"] + if "auth-token" in ws_services[name]: + token = ws_services[name]["auth-token"] headers["Authorization"] = f"Bearer {token}" logger.info(f"Invoking {remote_name} at {url}") diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index 6daba1a1..4d0c2689 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -264,7 +264,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=session_uri, - user=user, collection=collection, ), triples=triples, @@ -292,7 +291,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -329,7 +328,7 @@ class PatternBase: try: await self.processor.save_answer_content( doc_id=thought_doc_id, - user=request.user, + workspace=flow.workspace, content=act.thought, title=f"Agent Thought: {act.name}", ) @@ -360,7 +359,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=iteration_uri, - user=request.user, collection=getattr(request, 'collection', 'default'), ), triples=iter_triples, @@ -399,7 +397,7 @@ class PatternBase: try: await self.processor.save_answer_content( doc_id=observation_doc_id, - user=request.user, + workspace=flow.workspace, content=observation_text, title=f"Agent Observation", ) @@ -420,7 +418,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=observation_entity_uri, - user=request.user, collection=getattr(request, 'collection', 'default'), ), triples=obs_triples, @@ -456,7 +453,7 @@ class PatternBase: try: await self.processor.save_answer_content( doc_id=answer_doc_id, - user=request.user, + workspace=flow.workspace, content=answer_text, title=f"Agent Answer: {request.question[:50]}...", ) @@ -478,7 +475,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=final_uri, - user=request.user, collection=getattr(request, 'collection', 'default'), ), triples=final_triples, @@ -506,7 +502,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -532,7 +528,7 @@ class PatternBase: doc_id = f"urn:trustgraph:agent:{session_id}/finding/{index}/doc" try: await self.processor.save_answer_content( - doc_id=doc_id, user=user, + doc_id=doc_id, workspace=user, content=answer_text, title=f"Finding: {goal[:60]}", ) @@ -545,7 +541,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -565,7 +561,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -585,7 +581,7 @@ class PatternBase: doc_id = f"urn:trustgraph:agent:{session_id}/step/{index}/doc" try: await self.processor.save_answer_content( - doc_id=doc_id, user=user, + doc_id=doc_id, workspace=user, content=answer_text, title=f"Step result: {goal[:60]}", ) @@ -598,7 +594,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -617,7 +613,7 @@ class PatternBase: doc_id = f"urn:trustgraph:agent:{session_id}/synthesis/doc" try: await self.processor.save_answer_content( - doc_id=doc_id, user=user, + doc_id=doc_id, workspace=user, content=answer_text, title="Synthesis", ) @@ -633,7 +629,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py index 1de31a92..aad0416a 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -109,7 +109,13 @@ class PlanThenExecutePattern(PatternBase): think = self.make_think_callback(respond, streaming) - tools = self.filter_tools(self.processor.agent.tools, request) + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + + tools = self.filter_tools(agent.tools, request) framing = getattr(request, 'framing', '') context = self.make_context( @@ -237,7 +243,13 @@ class PlanThenExecutePattern(PatternBase): "result": dep_result, }) - tools = self.filter_tools(self.processor.agent.tools, request) + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + + tools = self.filter_tools(agent.tools, request) context = self.make_context( flow, request.user, respond=respond, streaming=streaming, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py index 25264c26..0186347e 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py @@ -80,13 +80,20 @@ class ReactPattern(PatternBase): observe = self.make_observe_callback(respond, streaming, message_id=observation_msg_id) answer_cb = self.make_answer_callback(respond, streaming, message_id=answer_msg_id) + # Look up the per-workspace agent + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + # Filter tools filtered_tools = self.filter_tools( - self.processor.agent.tools, request, + agent.tools, request, ) # Create temporary agent with filtered tools and optional framing - additional_context = self.processor.agent.additional_context + additional_context = agent.additional_context framing = getattr(request, 'framing', '') if framing: if additional_context: diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py index 3d08154d..0c421323 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -76,10 +76,9 @@ class Processor(AgentService): } ) - self.agent = AgentManager( - tools={}, - additional_context="", - ) + # Per-workspace agent managers and meta-routers + self.agents = {} + self.meta_routers = {} self.tool_service_clients = {} @@ -91,9 +90,6 @@ class Processor(AgentService): # Aggregator for supervisor fan-in self.aggregator = Aggregator() - # Meta-router (initialised on first config load) - self.meta_router = None - self.register_config_handler( self.on_tools_config, types=["tool", "tool-service"] ) @@ -204,13 +200,13 @@ class Processor(AgentService): future = self.pending_librarian_requests.pop(request_id) future.set_result(response) - async def save_answer_content(self, doc_id, user, content, title=None, + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): request_id = str(uuid.uuid4()) doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "Agent Answer", document_type="answer", @@ -221,7 +217,7 @@ class Processor(AgentService): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) future = asyncio.get_event_loop().create_future() @@ -247,9 +243,12 @@ class Processor(AgentService): def provenance_session_uri(self, session_id): return agent_session_uri(session_id) - async def on_tools_config(self, config, version): + async def on_tools_config(self, workspace, config, version): - logger.info(f"Loading configuration version {version}") + logger.info( + f"Loading configuration version {version} " + f"for workspace {workspace}" + ) try: tools = {} @@ -408,15 +407,17 @@ class Processor(AgentService): agent_config = config[self.config_key] additional = agent_config.get("additional-context", None) - self.agent = AgentManager( + self.agents[workspace] = AgentManager( tools=tools, additional_context=additional, ) - # Re-initialise meta-router with config - self.meta_router = MetaRouter(config=config) + # Re-initialise meta-router with config for this workspace + self.meta_routers[workspace] = MetaRouter(config=config) - logger.info(f"Loaded {len(tools)} tools") + logger.info( + f"Loaded {len(tools)} tools for workspace {workspace}" + ) except Exception as e: logger.error( @@ -517,8 +518,9 @@ class Processor(AgentService): if not pattern and not request.history: context = UserAwareContext(flow, request.user) - if self.meta_router: - pattern, task_type, framing = await self.meta_router.route( + meta_router = self.meta_routers.get(flow.workspace) + if meta_router: + pattern, task_type, framing = await meta_router.route( request.question, context, usage=usage, ) else: diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py index 973a9966..6fd6233b 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -99,7 +99,13 @@ class SupervisorPattern(PatternBase): ) framing = getattr(request, 'framing', '') - tools = self.filter_tools(self.processor.agent.tools, request) + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + + tools = self.filter_tools(agent.tools, request) context = self.make_context( flow, request.user, diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 1512fa83..66559d9c 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -10,6 +10,7 @@ import sys import functools import logging import uuid +from typing import Dict from datetime import datetime, timezone # Module logger @@ -73,10 +74,8 @@ class Processor(AgentService): } ) - self.agent = AgentManager( - tools={}, - additional_context="", - ) + # Per-workspace agent managers + self.agents: Dict[str, AgentManager] = {} # Track active tool service clients for cleanup self.tool_service_clients = {} @@ -193,7 +192,7 @@ class Processor(AgentService): future = self.pending_librarian_requests.pop(request_id) future.set_result(response) - async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): """ Save answer content to the librarian. @@ -211,7 +210,7 @@ class Processor(AgentService): doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "Agent Answer", document_type="answer", @@ -222,7 +221,7 @@ class Processor(AgentService): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) # Create future for response @@ -249,9 +248,12 @@ class Processor(AgentService): self.pending_librarian_requests.pop(request_id, None) raise RuntimeError(f"Timeout saving answer document {doc_id}") - async def on_tools_config(self, config, version): + async def on_tools_config(self, workspace, config, version): - logger.info(f"Loading configuration version {version}") + logger.info( + f"Loading configuration version {version} " + f"for workspace {workspace}" + ) try: @@ -409,13 +411,17 @@ class Processor(AgentService): agent_config = config[self.config_key] additional = agent_config.get("additional-context", None) - self.agent = AgentManager( + self.agents[workspace] = AgentManager( tools=tools, additional_context=additional ) - logger.info(f"Loaded {len(tools)} tools") - logger.info("Tool configuration reloaded.") + logger.info( + f"Loaded {len(tools)} tools for workspace {workspace}" + ) + logger.info( + f"Tool configuration reloaded for workspace {workspace}." + ) except Exception as e: @@ -460,7 +466,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=session_uri, - user=request.user, collection=collection, ), triples=triples, @@ -557,17 +562,29 @@ class Processor(AgentService): await respond(r) + # Look up the agent for this workspace + workspace = flow.workspace + agent = self.agents.get(workspace) + if agent is None: + logger.error( + f"No agent configuration loaded for workspace " + f"{workspace}" + ) + raise RuntimeError( + f"No agent configuration for workspace {workspace}" + ) + # Apply tool filtering based on request groups and state filtered_tools = filter_tools_by_group_and_state( - tools=self.agent.tools, + tools=agent.tools, requested_groups=getattr(request, 'group', None), current_state=getattr(request, 'state', None) ) - + # Create temporary agent with filtered tools temp_agent = AgentManager( tools=filtered_tools, - additional_context=self.agent.additional_context + additional_context=agent.additional_context ) logger.debug("Call React") @@ -604,7 +621,7 @@ class Processor(AgentService): try: await self.save_answer_content( doc_id=t_doc_id, - user=request.user, + workspace=flow.workspace, content=act_decision.thought, title=f"Agent Thought: {act_decision.name}", ) @@ -629,7 +646,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=iter_uri, - user=request.user, collection=collection, ), triples=iter_triples, @@ -685,7 +701,7 @@ class Processor(AgentService): try: await self.save_answer_content( doc_id=answer_doc_id, - user=request.user, + workspace=flow.workspace, content=f, title=f"Agent Answer: {request.question[:50]}...", ) @@ -706,7 +722,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=final_uri, - user=request.user, collection=collection, ), triples=final_triples, @@ -763,7 +778,7 @@ class Processor(AgentService): try: await self.save_answer_content( doc_id=observation_doc_id, - user=request.user, + workspace=flow.workspace, content=act.observation, title=f"Agent Observation", ) @@ -783,7 +798,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=observation_entity_uri, - user=request.user, collection=collection, ), triples=obs_triples, diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index dc7b357c..a0052c79 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -95,7 +95,7 @@ class Processor(ChunkingService): logger.info(f"Chunking document {v.metadata.id}...") # Get text content (fetches from librarian if needed) - text = await self.get_document_text(v) + text = await self.get_document_text(v, flow.workspace) # Extract chunk parameters from flow (allows runtime override) chunk_size, chunk_overlap = await self.chunk_document( @@ -144,7 +144,7 @@ class Processor(ChunkingService): await self.librarian.save_child_document( doc_id=chunk_doc_id, parent_id=parent_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=chunk_content, document_type="chunk", title=f"Chunk {chunk_index}", @@ -168,7 +168,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -179,7 +178,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), chunk=chunk_content, diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index 3f31beb9..c3935e4b 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -92,7 +92,7 @@ class Processor(ChunkingService): logger.info(f"Chunking document {v.metadata.id}...") # Get text content (fetches from librarian if needed) - text = await self.get_document_text(v) + text = await self.get_document_text(v, flow.workspace) # Extract chunk parameters from flow (allows runtime override) chunk_size, chunk_overlap = await self.chunk_document( @@ -140,7 +140,7 @@ class Processor(ChunkingService): await self.librarian.save_child_document( doc_id=chunk_doc_id, parent_id=parent_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=chunk_content, document_type="chunk", title=f"Chunk {chunk_index}", @@ -164,7 +164,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -175,7 +174,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), chunk=chunk_content, diff --git a/trustgraph-flow/trustgraph/config/service/config.py b/trustgraph-flow/trustgraph/config/service/config.py index 6c897f6b..36af6026 100644 --- a/trustgraph-flow/trustgraph/config/service/config.py +++ b/trustgraph-flow/trustgraph/config/service/config.py @@ -9,42 +9,8 @@ from ... tables.config import ConfigTableStore # Module logger logger = logging.getLogger(__name__) -class ConfigurationClass: - - async def keys(self): - return await self.table_store.get_keys(self.type) - - async def values(self): - vals = await self.table_store.get_values(self.type) - return { - v[0]: v[1] - for v in vals - } - - async def get(self, key): - return await self.table_store.get_value(self.type, key) - - async def put(self, key, value): - return await self.table_store.put_config(self.type, key, value) - - async def delete(self, key): - return await self.table_store.delete_key(self.type, key) - - async def has(self, key): - val = await self.table_store.get_value(self.type, key) - return val is not None - class Configuration: - # FIXME: The state is held internally. This only works if there's - # one config service. Should be more than one, and use a - # back-end state store. - - # FIXME: This has state now, but does it address all of the above? - # REVIEW: Above - - # FIXME: Some version vs config race conditions - def __init__(self, push, host, username, password, keyspace): # External function to respond to update @@ -60,34 +26,17 @@ class Configuration: async def get_version(self): return await self.table_store.get_version() - def get(self, type): - - c = ConfigurationClass() - c.table_store = self.table_store - c.type = type - - return c - async def handle_get(self, v): - # for k in v.keys: - # if k.type not in self or k.key not in self[k.type]: - # return ConfigResponse( - # version = None, - # values = None, - # directory = None, - # config = None, - # error = Error( - # type = "key-error", - # message = f"Key error" - # ) - # ) + workspace = v.workspace values = [ ConfigValue( type = k.type, key = k.key, - value = await self.table_store.get_value(k.type, k.key) + value = await self.table_store.get_value( + workspace, k.type, k.key + ) ) for k in v.keys ] @@ -96,43 +45,19 @@ class Configuration: version = await self.get_version(), values = values, ) - + async def handle_list(self, v): - # if v.type not in self: - - # return ConfigResponse( - # version = None, - # values = None, - # directory = None, - # config = None, - # error = Error( - # type = "key-error", - # message = "No such type", - # ), - # ) - return ConfigResponse( version = await self.get_version(), - directory = await self.table_store.get_keys(v.type), + directory = await self.table_store.get_keys( + v.workspace, v.type + ), ) async def handle_getvalues(self, v): - # if v.type not in self: - - # return ConfigResponse( - # version = None, - # values = None, - # directory = None, - # config = None, - # error = Error( - # type = "key-error", - # message = f"Key error" - # ) - # ) - - vals = await self.table_store.get_values(v.type) + vals = await self.table_store.get_values(v.workspace, v.type) values = map( lambda x: ConfigValue( @@ -146,39 +71,63 @@ class Configuration: values = list(values), ) + async def handle_getvalues_all_ws(self, v): + """Fetch all values of a given type across all workspaces. + Used by shared processors to load type-scoped config at + startup without enumerating workspaces separately.""" + + vals = await self.table_store.get_values_all_ws(v.type) + + values = [ + ConfigValue( + workspace = row[0], + type = v.type, + key = row[1], + value = row[2], + ) + for row in vals + ] + + return ConfigResponse( + version = await self.get_version(), + values = values, + ) + async def handle_delete(self, v): + workspace = v.workspace types = list(set(k.type for k in v.keys)) for k in v.keys: - - await self.table_store.delete_key(k.type, k.key) + await self.table_store.delete_key(workspace, k.type, k.key) await self.inc_version() - await self.push(types=types) + await self.push(changes={t: [workspace] for t in types}) return ConfigResponse( ) async def handle_put(self, v): + workspace = v.workspace types = list(set(k.type for k in v.values)) for k in v.values: - - await self.table_store.put_config(k.type, k.key, k.value) + await self.table_store.put_config( + workspace, k.type, k.key, k.value + ) await self.inc_version() - await self.push(types=types) + await self.push(changes={t: [workspace] for t in types}) return ConfigResponse( ) - async def get_config(self): + async def get_config(self, workspace): - table = await self.table_store.get_all() + table = await self.table_store.get_all_for_workspace(workspace) config = {} @@ -191,7 +140,7 @@ class Configuration: async def handle_config(self, v): - config = await self.get_config() + config = await self.get_config(v.workspace) return ConfigResponse( version = await self.get_version(), @@ -200,7 +149,20 @@ class Configuration: async def handle(self, msg): - logger.debug(f"Handling config message: {msg.operation}") + logger.debug( + f"Handling config message: {msg.operation} " + f"workspace={msg.workspace}" + ) + + # getvalues-all-ws spans all workspaces, so no workspace + # required; everything else is workspace-scoped. + if msg.operation != "getvalues-all-ws" and not msg.workspace: + return ConfigResponse( + error=Error( + type = "bad-request", + message = "Workspace is required" + ) + ) if msg.operation == "get": @@ -214,6 +176,10 @@ class Configuration: resp = await self.handle_getvalues(msg) + elif msg.operation == "getvalues-all-ws": + + resp = await self.handle_getvalues_all_ws(msg) + elif msg.operation == "delete": resp = await self.handle_delete(msg) diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py index fe44b852..56a54ee0 100644 --- a/trustgraph-flow/trustgraph/config/service/service.py +++ b/trustgraph-flow/trustgraph/config/service/service.py @@ -128,18 +128,21 @@ class Processor(AsyncProcessor): await self.push() # Startup poke: empty types = everything await self.config_request_consumer.start() - async def push(self, types=None): + async def push(self, changes=None): version = await self.config.get_version() resp = ConfigPush( version = version, - types = types or [], + changes = changes or {}, ) await self.config_push_producer.send(resp) - logger.info(f"Pushed config poke version {version}, types={resp.types}") + logger.info( + f"Pushed config poke version {version}, " + f"changes={resp.changes}" + ) async def on_config_request(self, msg, consumer, flow): diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py index d03d4ed6..ab5f78f0 100644 --- a/trustgraph-flow/trustgraph/cores/knowledge.py +++ b/trustgraph-flow/trustgraph/cores/knowledge.py @@ -33,7 +33,7 @@ class KnowledgeManager: logger.info("Deleting knowledge core...") await self.table_store.delete_kg_core( - request.user, request.id + request.workspace, request.id ) await respond( @@ -63,7 +63,7 @@ class KnowledgeManager: # Remove doc table row await self.table_store.get_triples( - request.user, + request.workspace, request.id, publish_triples, ) @@ -81,7 +81,7 @@ class KnowledgeManager: # Remove doc table row await self.table_store.get_graph_embeddings( - request.user, + request.workspace, request.id, publish_ge, ) @@ -100,7 +100,7 @@ class KnowledgeManager: async def list_kg_cores(self, request, respond): - ids = await self.table_store.list_kg_cores(request.user) + ids = await self.table_store.list_kg_cores(request.workspace) await respond( KnowledgeResponse( @@ -114,12 +114,14 @@ class KnowledgeManager: async def put_kg_core(self, request, respond): + workspace = request.workspace + if request.triples: - await self.table_store.add_triples(request.triples) + await self.table_store.add_triples(workspace, request.triples) if request.graph_embeddings: await self.table_store.add_graph_embeddings( - request.graph_embeddings + workspace, request.graph_embeddings ) await respond( @@ -178,10 +180,15 @@ class KnowledgeManager: if request.flow is None: raise RuntimeError("Flow ID must be specified") - if request.flow not in self.flow_config.flows: - raise RuntimeError("Invalid flow") + workspace = request.workspace + ws_flows = self.flow_config.flows.get(workspace, {}) + if request.flow not in ws_flows: + raise RuntimeError( + f"Invalid flow {request.flow} for workspace " + f"{workspace}" + ) - flow = self.flow_config.flows[request.flow] + flow = ws_flows[request.flow] if "interfaces" not in flow: raise RuntimeError("No defined interfaces") @@ -257,7 +264,7 @@ class KnowledgeManager: # Remove doc table row await self.table_store.get_triples( - request.user, + request.workspace, request.id, publish_triples, ) @@ -272,7 +279,7 @@ class KnowledgeManager: # Remove doc table row await self.table_store.get_graph_embeddings( - request.user, + request.workspace, request.id, publish_ge, ) diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index 93017c30..15e8feb6 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -124,19 +124,21 @@ class Processor(AsyncProcessor): await self.knowledge_request_consumer.start() await self.knowledge_response_producer.start() - async def on_knowledge_config(self, config, version): + async def on_knowledge_config(self, workspace, config, version): - logger.info(f"Configuration version: {version}") + logger.info( + f"Configuration version: {version} workspace: {workspace}" + ) if "flow" in config: - self.flows = { + self.flows[workspace] = { k: json.loads(v) for k, v in config["flow"].items() } else: - self.flows = {} + self.flows[workspace] = {} - logger.debug(f"Flows: {self.flows}") + logger.debug(f"Flows for {workspace}: {self.flows[workspace]}") async def process_request(self, v, id): diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py index 40b8c566..3436ca51 100755 --- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py @@ -200,7 +200,7 @@ class Processor(FlowProcessor): if v.document_id: doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": logger.error( @@ -215,7 +215,7 @@ class Processor(FlowProcessor): logger.info(f"Fetching document {v.document_id} from librarian...") content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if isinstance(content, str): content = content.encode('utf-8') @@ -243,7 +243,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=page_content, document_type="page", title=f"Page {page_num}", @@ -265,7 +265,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -277,7 +276,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), document_id=page_doc_id, diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index 7f9ca71d..f3eb3881 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -93,7 +93,7 @@ class Processor(FlowProcessor): if v.document_id: doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": logger.error( @@ -114,7 +114,7 @@ class Processor(FlowProcessor): content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) # Content is base64 encoded @@ -157,7 +157,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=page_content, document_type="page", title=f"Page {page_num}", @@ -179,7 +179,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -191,7 +190,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), document_id=page_doc_id, diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py index 362bdec9..0ac3a860 100644 --- a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py @@ -69,19 +69,26 @@ class Processor(CollectionConfigHandler, FlowProcessor): self.register_config_handler(self.on_schema_config, types=["schema"]) self.register_config_handler(self.on_collection_config, types=["collection"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type @@ -115,13 +122,19 @@ class Processor(CollectionConfigHandler, FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) def get_index_names(self, schema: RowSchema) -> List[str]: """Get all index names for a schema.""" @@ -149,23 +162,29 @@ class Processor(CollectionConfigHandler, FlowProcessor): """Process incoming ExtractedObject and compute embeddings""" obj = msg.value() + workspace = flow.workspace logger.info( f"Computing embeddings for {len(obj.values)} rows, " - f"schema {obj.schema_name}, doc {obj.metadata.id}" + f"schema {obj.schema_name}, doc {obj.metadata.id}, " + f"workspace {workspace}" ) # Validate collection exists before processing - if not self.collection_exists(obj.metadata.user, obj.metadata.collection): + if not self.collection_exists(workspace, obj.metadata.collection): logger.warning( - f"Collection {obj.metadata.collection} for user {obj.metadata.user} " + f"Collection {obj.metadata.collection} for workspace {workspace} " f"does not exist in config. Dropping message." ) return - # Get schema definition - schema = self.schemas.get(obj.schema_name) + # Get schema definition for this workspace + ws_schemas = self.schemas.get(workspace, {}) + schema = ws_schemas.get(obj.schema_name) if not schema: - logger.warning(f"No schema found for {obj.schema_name} - skipping") + logger.warning( + f"No schema found for {obj.schema_name} in " + f"workspace {workspace} - skipping" + ) return # Get all index names for this schema diff --git a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py index ce8d6aae..285b956c 100644 --- a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py @@ -75,24 +75,36 @@ class Processor(FlowProcessor): ) ) - # Null configuration, should reload quickly - self.manager = PromptManager() + # Per-workspace prompt managers + self.managers = {} - async def on_prompt_config(self, config, version): + async def on_prompt_config(self, workspace, config, version): - logger.info(f"Loading configuration version {version}") + logger.info( + f"Loading configuration version {version} " + f"for workspace {workspace}" + ) if self.config_key not in config: - logger.warning(f"No key {self.config_key} in config") + logger.warning( + f"No key {self.config_key} in config for {workspace}" + ) return - config = config[self.config_key] + prompt_config = config[self.config_key] try: - self.manager.load_config(config) + manager = self.managers.get(workspace) + if manager is None: + manager = PromptManager() + self.managers[workspace] = manager - logger.info("Prompt configuration reloaded") + manager.load_config(prompt_config) + + logger.info( + f"Prompt configuration reloaded for {workspace}" + ) except Exception as e: @@ -107,7 +119,6 @@ class Processor(FlowProcessor): metadata = Metadata( id = metadata.id, root = metadata.root, - user = metadata.user, collection = metadata.collection, ), triples = triples, @@ -120,7 +131,6 @@ class Processor(FlowProcessor): metadata = Metadata( id = metadata.id, root = metadata.root, - user = metadata.user, collection = metadata.collection, ), entities = entity_contexts, @@ -170,13 +180,24 @@ class Processor(FlowProcessor): try: v = msg.value() + workspace = flow.workspace # Extract chunk text chunk_text = v.chunk.decode('utf-8') - logger.debug("Processing chunk for agent extraction") + logger.debug( + f"Processing chunk for agent extraction, " + f"workspace {workspace}" + ) - prompt = self.manager.render( + manager = self.managers.get(workspace) + if manager is None: + logger.error( + f"No prompt configuration for workspace {workspace}" + ) + return + + prompt = manager.render( self.template_id, { "text": chunk_text diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index 9b5bbb79..31f45ae9 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -213,7 +213,6 @@ class Processor(FlowProcessor): Metadata( id=v.metadata.id, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), batch @@ -227,7 +226,6 @@ class Processor(FlowProcessor): Metadata( id=v.metadata.id, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), batch diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index e024ad40..a05f4dfe 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -109,20 +109,22 @@ class Processor(FlowProcessor): # Register config handler for ontology updates self.register_config_handler(self.on_ontology_config, types=["ontology"]) - # Shared components (not flow-specific) - self.ontology_loader = OntologyLoader() + # Per-workspace ontology loaders + self.ontology_loaders = {} # workspace -> OntologyLoader self.text_processor = TextProcessor() - # Per-flow components (each flow gets its own embedder/vector store/selector) - self.flow_components = {} # flow_id -> {embedder, vector_store, selector} + # Per-flow components (each flow gets its own embedder/vector + # store/selector). Keyed by id(flow) — Flow objects are unique + # per (workspace, flow), so this is implicitly workspace-scoped. + self.flow_components = {} # Configuration self.top_k = params.get("top_k", 10) self.similarity_threshold = params.get("similarity_threshold", 0.3) - # Track loaded ontology version - self.current_ontology_version = None - self.loaded_ontology_ids = set() + # Per-workspace ontology version tracking + self.current_ontology_versions = {} # workspace -> version + self.loaded_ontology_ids = {} # workspace -> set of ids async def initialize_flow_components(self, flow): """Initialize per-flow OntoRAG components. @@ -167,17 +169,23 @@ class Processor(FlowProcessor): vector_store=vector_store ) - # Embed all loaded ontologies for this flow - if self.ontology_loader.get_all_ontologies(): - logger.info(f"Embedding ontologies for flow {flow_id}") - for ont_id, ontology in self.ontology_loader.get_all_ontologies().items(): + workspace = flow.workspace + + # Embed all loaded ontologies for this workspace + loader = self.ontology_loaders.get(workspace) + if loader is not None and loader.get_all_ontologies(): + logger.info( + f"Embedding ontologies for flow {flow_id} " + f"(workspace {workspace})" + ) + for ont_id, ontology in loader.get_all_ontologies().items(): await ontology_embedder.embed_ontology(ontology) logger.info(f"Embedded {ontology_embedder.get_embedded_count()} ontology elements for flow {flow_id}") # Initialize ontology selector ontology_selector = OntologySelector( ontology_embedder=ontology_embedder, - ontology_loader=self.ontology_loader, + ontology_loader=loader, top_k=self.top_k, similarity_threshold=self.similarity_threshold ) @@ -187,7 +195,8 @@ class Processor(FlowProcessor): 'embedder': ontology_embedder, 'vector_store': vector_store, 'selector': ontology_selector, - 'dimension': dimension + 'dimension': dimension, + 'workspace': workspace, } logger.info(f"Flow {flow_id} components initialized successfully (dimension={dimension})") @@ -197,31 +206,27 @@ class Processor(FlowProcessor): logger.error(f"Failed to initialize flow {flow_id} components: {e}", exc_info=True) raise - async def on_ontology_config(self, config, version): - """ - Handle ontology configuration updates from ConfigPush queue. - - Parses and stores ontologies. Embedding happens per-flow on first message. - - Called automatically when: - - Processor starts (gets full config history via start_of_messages=True) - - Config service pushes updates (immediate event-driven notification) - - Args: - config: Full configuration map - config[type][key] = value - version: Config version number (monotonically increasing) - """ + async def on_ontology_config(self, workspace, config, version): + """Handle ontology configuration updates for a workspace.""" try: - logger.info(f"Received ontology config update, version={version}") + logger.info( + f"Received ontology config update, " + f"version={version} workspace={workspace}" + ) - # Skip if we've already processed this version - if version == self.current_ontology_version: - logger.debug(f"Already at version {version}, skipping") + # Skip if we've already processed this version for this workspace + if version == self.current_ontology_versions.get(workspace): + logger.debug( + f"Already at version {version} for {workspace}, " + f"skipping" + ) return # Extract ontology configurations if "ontology" not in config: - logger.warning("No 'ontology' section in config") + logger.warning( + f"No 'ontology' section in config for {workspace}" + ) return ontology_configs = config["ontology"] @@ -235,38 +240,65 @@ class Processor(FlowProcessor): logger.error(f"Failed to parse ontology '{ont_id}': {e}") continue - logger.info(f"Loaded {len(ontologies)} ontology definitions") + logger.info( + f"Loaded {len(ontologies)} ontology definitions " + f"for {workspace}" + ) - # Determine what changed (for incremental updates) + # Determine what changed for this workspace + ws_loaded_ids = self.loaded_ontology_ids.get(workspace, set()) new_ids = set(ontologies.keys()) - added_ids = new_ids - self.loaded_ontology_ids - removed_ids = self.loaded_ontology_ids - new_ids - updated_ids = new_ids & self.loaded_ontology_ids # May have changed content + added_ids = new_ids - ws_loaded_ids + removed_ids = ws_loaded_ids - new_ids + updated_ids = new_ids & ws_loaded_ids # May have changed content if added_ids: - logger.info(f"New ontologies: {added_ids}") + logger.info(f"New ontologies in {workspace}: {added_ids}") if removed_ids: - logger.info(f"Removed ontologies: {removed_ids}") + logger.info(f"Removed ontologies in {workspace}: {removed_ids}") if updated_ids: - logger.info(f"Updated ontologies: {updated_ids}") + logger.info(f"Updated ontologies in {workspace}: {updated_ids}") - # Update ontology loader's internal state - self.ontology_loader.update_ontologies(ontologies) + # Get or create per-workspace loader + loader = self.ontology_loaders.get(workspace) + if loader is None: + loader = OntologyLoader() + self.ontology_loaders[workspace] = loader + loader.update_ontologies(ontologies) - # Clear all flow components to force re-embedding with new ontologies + # Clear flow components for this workspace to force + # re-embedding with new ontologies. if added_ids or removed_ids or updated_ids: - logger.info("Clearing flow components to trigger re-embedding") - self.flow_components.clear() + self._clear_workspace_flow_components(workspace) # Update tracking - self.current_ontology_version = version - self.loaded_ontology_ids = new_ids + self.current_ontology_versions[workspace] = version + self.loaded_ontology_ids[workspace] = new_ids - logger.info(f"Ontology config update complete, version={version}") + logger.info( + f"Ontology config update complete for {workspace}, " + f"version={version}" + ) except Exception as e: logger.error(f"Failed to process ontology config: {e}", exc_info=True) + def _clear_workspace_flow_components(self, workspace): + """Drop cached flow components belonging to the given workspace + so they're re-initialised on next message with fresh ontology + embeddings.""" + to_remove = [ + fid for fid, comp in self.flow_components.items() + if comp.get("workspace") == workspace + ] + if to_remove: + logger.info( + f"Clearing {len(to_remove)} flow components for " + f"workspace {workspace}" + ) + for fid in to_remove: + del self.flow_components[fid] + async def on_message(self, msg, consumer, flow): """Process incoming chunk message.""" v = msg.value() @@ -624,7 +656,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=metadata.id, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), triples=triples, @@ -637,7 +668,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=metadata.id, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), entities=entities, diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index 8068a23d..ee3e2ed2 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -207,7 +207,6 @@ class Processor(FlowProcessor): Metadata( id=v.metadata.id, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), batch diff --git a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py index 973bb3d7..f1dd4fe0 100644 --- a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py @@ -84,32 +84,39 @@ class Processor(FlowProcessor): # Register config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type schemas_config = config[self.config_key] - + # Process each schema in the schemas config for schema_name, schema_json in schemas_config.items(): - + try: # Parse the JSON schema definition schema_def = json.loads(schema_json) - + # Create Field objects fields = [] for field_def in schema_def.get("fields", []): @@ -124,21 +131,27 @@ class Processor(FlowProcessor): indexed=field_def.get("indexed", False) ) fields.append(field) - + # Create RowSchema row_schema = RowSchema( name=schema_def.get("name", schema_name), description=schema_def.get("description", ""), fields=fields ) - - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") - + + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) + except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) async def extract_objects_for_schema(self, text: str, schema_name: str, schema: RowSchema, flow) -> List[Dict[str, Any]]: """Extract objects from text for a specific schema""" @@ -234,18 +247,26 @@ class Processor(FlowProcessor): """Process incoming chunk and extract objects""" v = msg.value() - logger.info(f"Extracting objects from chunk {v.metadata.id}...") + workspace = flow.workspace + logger.info( + f"Extracting objects from chunk {v.metadata.id} " + f"(workspace {workspace})..." + ) chunk_text = v.chunk.decode("utf-8") - # If no schemas configured, log warning and return - if not self.schemas: - logger.warning("No schemas configured - skipping extraction") + # If no schemas configured for this workspace, log and return + ws_schemas = self.schemas.get(workspace, {}) + if not ws_schemas: + logger.warning( + f"No schemas configured for workspace {workspace} " + f"- skipping extraction" + ) return try: # Extract objects for each configured schema - for schema_name, schema in self.schemas.items(): + for schema_name, schema in ws_schemas.items(): logger.debug(f"Extracting {schema_name} objects from chunk") @@ -274,7 +295,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=f"{v.metadata.id}:{schema_name}", root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), schema_name=schema_name, diff --git a/trustgraph-flow/trustgraph/flow/service/flow.py b/trustgraph-flow/trustgraph/flow/service/flow.py index b864faf9..b8a6f8ce 100644 --- a/trustgraph-flow/trustgraph/flow/service/flow.py +++ b/trustgraph-flow/trustgraph/flow/service/flow.py @@ -17,14 +17,18 @@ class FlowConfig: self.config = config self.pubsub = pubsub - # Cache for parameter type definitions to avoid repeated lookups + # Per-workspace cache for parameter type definitions + # Keyed by (workspace, type-name) self.param_type_cache = {} - async def resolve_parameters(self, flow_blueprint, user_params): + async def resolve_parameters( + self, workspace, flow_blueprint, user_params + ): """ Resolve parameters by merging user-provided values with defaults. Args: + workspace: Workspace containing the parameter-type definitions flow_blueprint: The flow blueprint definition dict user_params: User-provided parameters dict (may be None or empty) @@ -55,24 +59,25 @@ class FlowConfig: # Look up the parameter type definition param_type = param_meta.get("type") if param_type: + cache_key = (workspace, param_type) # Check cache first - if param_type not in self.param_type_cache: + if cache_key not in self.param_type_cache: try: # Fetch parameter type definition from config store type_def = await self.config.get( - "parameter-type", param_type + workspace, "parameter-type", param_type ) if type_def: - self.param_type_cache[param_type] = json.loads(type_def) + self.param_type_cache[cache_key] = json.loads(type_def) else: logger.warning(f"Parameter type '{param_type}' not found in config") - self.param_type_cache[param_type] = {} + self.param_type_cache[cache_key] = {} except Exception as e: logger.error(f"Error fetching parameter type '{param_type}': {e}") - self.param_type_cache[param_type] = {} + self.param_type_cache[cache_key] = {} # Apply default from type definition (as string) - type_def = self.param_type_cache[param_type] + type_def = self.param_type_cache[cache_key] if "default" in type_def: default_value = type_def["default"] # Convert to string based on type @@ -94,8 +99,9 @@ class FlowConfig: else: # Controller has no value, try to get default from type definition param_type = param_meta.get("type") - if param_type and param_type in self.param_type_cache: - type_def = self.param_type_cache[param_type] + cache_key = (workspace, param_type) if param_type else None + if cache_key and cache_key in self.param_type_cache: + type_def = self.param_type_cache[cache_key] if "default" in type_def: default_value = type_def["default"] # Convert to string based on type @@ -114,7 +120,9 @@ class FlowConfig: async def handle_list_blueprints(self, msg): - names = list(await self.config.keys("flow-blueprint")) + names = list(await self.config.keys( + msg.workspace, "flow-blueprint" + )) return FlowResponse( error = None, @@ -126,14 +134,14 @@ class FlowConfig: return FlowResponse( error = None, blueprint_definition = await self.config.get( - "flow-blueprint", msg.blueprint_name + msg.workspace, "flow-blueprint", msg.blueprint_name ), ) async def handle_put_blueprint(self, msg): await self.config.put( - "flow-blueprint", + msg.workspace, "flow-blueprint", msg.blueprint_name, msg.blueprint_definition ) @@ -145,7 +153,9 @@ class FlowConfig: logger.debug(f"Flow config message: {msg}") - await self.config.delete("flow-blueprint", msg.blueprint_name) + await self.config.delete( + msg.workspace, "flow-blueprint", msg.blueprint_name + ) return FlowResponse( error = None, @@ -153,7 +163,7 @@ class FlowConfig: async def handle_list_flows(self, msg): - names = list(await self.config.keys("flow")) + names = list(await self.config.keys(msg.workspace, "flow")) return FlowResponse( error = None, @@ -162,7 +172,9 @@ class FlowConfig: async def handle_get_flow(self, msg): - flow_data = await self.config.get("flow", msg.flow_id) + flow_data = await self.config.get( + msg.workspace, "flow", msg.flow_id + ) flow = json.loads(flow_data) return FlowResponse( @@ -174,37 +186,49 @@ class FlowConfig: async def handle_start_flow(self, msg): + workspace = msg.workspace + if msg.blueprint_name is None: raise RuntimeError("No blueprint name") if msg.flow_id is None: raise RuntimeError("No flow ID") - if msg.flow_id in await self.config.keys("flow"): + if msg.flow_id in await self.config.keys(workspace, "flow"): raise RuntimeError("Flow already exists") if msg.description is None: raise RuntimeError("No description") - if msg.blueprint_name not in await self.config.keys("flow-blueprint"): + if msg.blueprint_name not in await self.config.keys( + workspace, "flow-blueprint" + ): raise RuntimeError("Blueprint does not exist") cls = json.loads( - await self.config.get("flow-blueprint", msg.blueprint_name) + await self.config.get( + workspace, "flow-blueprint", msg.blueprint_name + ) ) # Resolve parameters by merging user-provided values with defaults user_params = msg.parameters if msg.parameters else {} - parameters = await self.resolve_parameters(cls, user_params) + parameters = await self.resolve_parameters( + workspace, cls, user_params + ) # Log the resolved parameters for debugging logger.debug(f"User provided parameters: {user_params}") logger.debug(f"Resolved parameters (with defaults): {parameters}") - # Apply parameter substitution to template replacement function + # Apply parameter substitution to template replacement function. + # {workspace} is substituted from msg.workspace to isolate + # queue names across workspaces. def repl_template_with_params(tmp): result = tmp.replace( + "{workspace}", workspace + ).replace( "{blueprint}", msg.blueprint_name ).replace( "{id}", msg.flow_id @@ -253,7 +277,7 @@ class FlowConfig: json.dumps(entry), )) - await self.config.put_many(updates) + await self.config.put_many(workspace, updates) def repl_interface(i): return { @@ -270,7 +294,7 @@ class FlowConfig: interfaces = {} await self.config.put( - "flow", msg.flow_id, + workspace, "flow", msg.flow_id, json.dumps({ "description": msg.description, "blueprint-name": msg.blueprint_name, @@ -283,68 +307,77 @@ class FlowConfig: error = None, ) - async def ensure_existing_flow_topics(self): - """Ensure topics exist for all already-running flows. + async def ensure_existing_flow_topics(self, workspaces): + """Ensure topics exist for all already-running flows across + the given workspaces. Called on startup to handle flows that were started before this version of the flow service was deployed, or before a restart. """ - flow_ids = await self.config.keys("flow") + for workspace in workspaces: + flow_ids = await self.config.keys(workspace, "flow") - for flow_id in flow_ids: - try: - flow_data = await self.config.get("flow", flow_id) - if flow_data is None: - continue - - flow = json.loads(flow_data) - - blueprint_name = flow.get("blueprint-name") - if blueprint_name is None: - continue - - # Skip flows that are mid-shutdown - if flow.get("status") == "stopping": - continue - - parameters = flow.get("parameters", {}) - - blueprint_data = await self.config.get( - "flow-blueprint", blueprint_name - ) - if blueprint_data is None: - logger.warning( - f"Blueprint '{blueprint_name}' not found for " - f"flow '{flow_id}', skipping topic creation" + for flow_id in flow_ids: + try: + flow_data = await self.config.get( + workspace, "flow", flow_id ) - continue + if flow_data is None: + continue - cls = json.loads(blueprint_data) + flow = json.loads(flow_data) - def repl_template(tmp): - result = tmp.replace( - "{blueprint}", blueprint_name - ).replace( - "{id}", flow_id + blueprint_name = flow.get("blueprint-name") + if blueprint_name is None: + continue + + # Skip flows that are mid-shutdown + if flow.get("status") == "stopping": + continue + + parameters = flow.get("parameters", {}) + + blueprint_data = await self.config.get( + workspace, "flow-blueprint", blueprint_name ) - for param_name, param_value in parameters.items(): - result = result.replace( - f"{{{param_name}}}", str(param_value) + if blueprint_data is None: + logger.warning( + f"Blueprint '{blueprint_name}' not found " + f"for flow '{workspace}/{flow_id}', skipping " + f"topic creation" ) - return result + continue - topics = self._collect_flow_topics(cls, repl_template) - for topic in topics: - await self.pubsub.ensure_topic(topic) + cls = json.loads(blueprint_data) - logger.info( - f"Ensured topics for existing flow '{flow_id}'" - ) + def repl_template(tmp): + result = tmp.replace( + "{workspace}", workspace + ).replace( + "{blueprint}", blueprint_name + ).replace( + "{id}", flow_id + ) + for param_name, param_value in parameters.items(): + result = result.replace( + f"{{{param_name}}}", str(param_value) + ) + return result - except Exception as e: - logger.error( - f"Failed to ensure topics for flow '{flow_id}': {e}" - ) + topics = self._collect_flow_topics(cls, repl_template) + for topic in topics: + await self.pubsub.ensure_topic(topic) + + logger.info( + f"Ensured topics for existing flow " + f"'{workspace}/{flow_id}'" + ) + + except Exception as e: + logger.error( + f"Failed to ensure topics for flow " + f"'{workspace}/{flow_id}': {e}" + ) def _collect_flow_topics(self, cls, repl_template): """Collect unique topic identifiers from the blueprint. @@ -395,13 +428,17 @@ class FlowConfig: async def handle_stop_flow(self, msg): + workspace = msg.workspace + if msg.flow_id is None: raise RuntimeError("No flow ID") - if msg.flow_id not in await self.config.keys("flow"): + if msg.flow_id not in await self.config.keys(workspace, "flow"): raise RuntimeError("Flow ID invalid") - flow = json.loads(await self.config.get("flow", msg.flow_id)) + flow = json.loads( + await self.config.get(workspace, "flow", msg.flow_id) + ) if "blueprint-name" not in flow: raise RuntimeError("Internal error: flow has no flow blueprint") @@ -410,11 +447,15 @@ class FlowConfig: parameters = flow.get("parameters", {}) cls = json.loads( - await self.config.get("flow-blueprint", blueprint_name) + await self.config.get( + workspace, "flow-blueprint", blueprint_name + ) ) def repl_template(tmp): result = tmp.replace( + "{workspace}", workspace + ).replace( "{blueprint}", blueprint_name ).replace( "{id}", msg.flow_id @@ -431,7 +472,7 @@ class FlowConfig: # The config push tells processors to shut down their consumers. flow["status"] = "stopping" await self.config.put( - "flow", msg.flow_id, json.dumps(flow) + workspace, "flow", msg.flow_id, json.dumps(flow) ) # Delete all processor config entries for this flow. @@ -444,13 +485,13 @@ class FlowConfig: deletes.append((f"processor:{processor}", variant)) - await self.config.delete_many(deletes) + await self.config.delete_many(workspace, deletes) # Phase 2: Delete topics with retries, then remove the flow record. await self._delete_topics(topics) - if msg.flow_id in await self.config.keys("flow"): - await self.config.delete("flow", msg.flow_id) + if msg.flow_id in await self.config.keys(workspace, "flow"): + await self.config.delete(workspace, "flow", msg.flow_id) return FlowResponse( error = None, @@ -458,7 +499,18 @@ class FlowConfig: async def handle(self, msg): - logger.debug(f"Handling flow message: {msg.operation}") + logger.debug( + f"Handling flow message: {msg.operation} " + f"workspace={msg.workspace}" + ) + + if not msg.workspace: + return FlowResponse( + error=Error( + type="bad-request", + message="Workspace is required", + ), + ) if msg.operation == "list-blueprints": resp = await self.handle_list_blueprints(msg) diff --git a/trustgraph-flow/trustgraph/flow/service/service.py b/trustgraph-flow/trustgraph/flow/service/service.py index e1997452..74077ccb 100644 --- a/trustgraph-flow/trustgraph/flow/service/service.py +++ b/trustgraph-flow/trustgraph/flow/service/service.py @@ -103,7 +103,12 @@ class Processor(AsyncProcessor): await self.pubsub.ensure_topic(self.flow_request_topic) await self.config_client.start() - await self.flow.ensure_existing_flow_topics() + + # Discover workspaces with existing flow config and ensure + # their topics exist before we start accepting requests. + workspaces = await self.config_client.workspaces_for_type("flow") + await self.flow.ensure_existing_flow_topics(workspaces) + await self.flow_request_consumer.start() async def on_flow_request(self, msg, consumer, flow): diff --git a/trustgraph-flow/trustgraph/gateway/config/receiver.py b/trustgraph-flow/trustgraph/gateway/config/receiver.py index c721a46a..5bc781a9 100755 --- a/trustgraph-flow/trustgraph/gateway/config/receiver.py +++ b/trustgraph-flow/trustgraph/gateway/config/receiver.py @@ -30,6 +30,7 @@ class ConfigReceiver: self.flow_handlers = [] + # Per-workspace flow tracking: {workspace: {flow_id: flow_def}} self.flows = {} self.config_version = 0 @@ -43,7 +44,7 @@ class ConfigReceiver: v = msg.value() notify_version = v.version - notify_types = set(v.types) + changes = v.changes # Skip if we already have this version or newer if notify_version <= self.config_version: @@ -53,20 +54,27 @@ class ConfigReceiver: ) return - # Gateway cares about flow config - if notify_types and "flow" not in notify_types: + # Gateway cares about flow config — check if any flow + # types changed in any workspace + flow_workspaces = changes.get("flow", []) + if changes and not flow_workspaces: logger.debug( f"Ignoring config notify v{notify_version}, " - f"no flow types in {notify_types}" + f"no flow changes" ) self.config_version = notify_version return logger.info( - f"Config notify v{notify_version}, fetching config..." + f"Config notify v{notify_version} " + f"types={list(changes.keys())}, fetching config..." ) - await self.fetch_and_apply() + # Refresh config for each affected workspace + for workspace in flow_workspaces: + await self.fetch_and_apply_workspace(workspace) + + self.config_version = notify_version except Exception as e: logger.error( @@ -98,20 +106,25 @@ class ConfigReceiver: response_metrics=config_resp_metrics, ) - async def fetch_and_apply(self, retry=False): - """Fetch full config and apply flow changes. + async def fetch_and_apply_workspace(self, workspace, retry=False): + """Fetch config for a single workspace and apply flow changes. If retry=True, keeps retrying until successful.""" while True: try: - logger.info("Fetching config from config service...") + logger.info( + f"Fetching config for workspace {workspace}..." + ) client = self._create_config_client() try: await client.start() resp = await client.request( - ConfigRequest(operation="config"), + ConfigRequest( + operation="config", + workspace=workspace, + ), timeout=10, ) finally: @@ -137,18 +150,22 @@ class ConfigReceiver: flows = config.get("flow", {}) + ws_flows = self.flows.get(workspace, {}) + wanted = list(flows.keys()) - current = list(self.flows.keys()) + current = list(ws_flows.keys()) for k in wanted: if k not in current: - self.flows[k] = json.loads(flows[k]) - await self.start_flow(k, self.flows[k]) + ws_flows[k] = json.loads(flows[k]) + await self.start_flow(workspace, k, ws_flows[k]) for k in current: if k not in wanted: - await self.stop_flow(k, self.flows[k]) - del self.flows[k] + await self.stop_flow(workspace, k, ws_flows[k]) + del ws_flows[k] + + self.flows[workspace] = ws_flows return @@ -164,27 +181,91 @@ class ConfigReceiver: ) return - async def start_flow(self, id, flow): + async def fetch_all_workspaces(self, retry=False): + """Fetch config for all workspaces at startup. + Discovers workspaces via the config service getvalues-all-ws + operation on the flow type.""" - logger.info(f"Starting flow: {id}") + while True: + + try: + logger.info("Discovering workspaces with flows...") + + client = self._create_config_client() + try: + await client.start() + + # Discover workspaces that have any flow config + resp = await client.request( + ConfigRequest( + operation="getvalues-all-ws", + type="flow", + ), + timeout=10, + ) + + if resp.error: + raise RuntimeError( + f"Config error: {resp.error.message}" + ) + + workspaces = { + v.workspace for v in resp.values if v.workspace + } + + # Always include the default workspace, even if + # empty, so that newly-created flows in it can be + # picked up by subsequent notifications. + workspaces.add("default") + + logger.info( + f"Found workspaces with flows: {workspaces}" + ) + + finally: + await client.stop() + + # Fetch and apply config for each workspace + for workspace in workspaces: + await self.fetch_and_apply_workspace( + workspace, retry=retry + ) + + return + + except Exception as e: + if retry: + logger.warning( + f"Workspace fetch failed: {e}, retrying in 2s..." + ) + await asyncio.sleep(2) + continue + logger.error( + f"Workspace fetch exception: {e}", exc_info=True + ) + return + + async def start_flow(self, workspace, id, flow): + + logger.info(f"Starting flow: {workspace}/{id}") for handler in self.flow_handlers: try: - await handler.start_flow(id, flow) + await handler.start_flow(workspace, id, flow) except Exception as e: logger.error( f"Config processing exception: {e}", exc_info=True ) - async def stop_flow(self, id, flow): + async def stop_flow(self, workspace, id, flow): - logger.info(f"Stopping flow: {id}") + logger.info(f"Stopping flow: {workspace}/{id}") for handler in self.flow_handlers: try: - await handler.stop_flow(id, flow) + await handler.stop_flow(workspace, id, flow) except Exception as e: logger.error( f"Config processing exception: {e}", exc_info=True @@ -218,7 +299,7 @@ class ConfigReceiver: # Fetch current config (subscribe-then-fetch pattern) # Retry until config service is available - await self.fetch_and_apply(retry=True) + await self.fetch_all_workspaces(retry=True) logger.info( "Config loader initialised, waiting for notifys..." diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py index 3a37c4e3..6696afbe 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py @@ -16,7 +16,7 @@ class CoreExport: async def process(self, data, error, ok, request): id = request.query["id"] - user = request.query["user"] + workspace = request.query.get("workspace", "default") response = await ok() @@ -41,7 +41,6 @@ class CoreExport: { "m": { "i": data["metadata"]["id"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "e": [ @@ -65,7 +64,6 @@ class CoreExport: { "m": { "i": data["metadata"]["id"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "t": data["triples"], @@ -78,7 +76,7 @@ class CoreExport: await kr.process( { "operation": "get-kg-core", - "user": user, + "workspace": workspace, "id": id, }, responder diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py index 0ca07319..d03d4efd 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py @@ -17,7 +17,7 @@ class CoreImport: async def process(self, data, error, ok, request): id = request.query["id"] - user = request.query["user"] + workspace = request.query.get("workspace", "default") kr = KnowledgeRequestor( backend = self.backend, @@ -43,12 +43,11 @@ class CoreImport: msg = unpacked[1] msg = { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "triples": { "metadata": { "id": id, - "user": user, "collection": "default", # Not used? }, "triples": msg["t"], @@ -61,12 +60,11 @@ class CoreImport: msg = unpacked[1] msg = { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "graph-embeddings": { "metadata": { "id": id, - "user": user, "collection": "default", # Not used? }, "entities": [ diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py index e70bf6de..2992d99f 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py @@ -14,12 +14,12 @@ class DocumentStreamExport: async def process(self, data, error, ok, request): - user = request.query.get("user") + workspace = request.query.get("workspace", "default") document_id = request.query.get("document-id") chunk_size = int(request.query.get("chunk-size", 1024 * 1024)) - if not user or not document_id: - return await error("Missing required parameters: user, document-id") + if not document_id: + return await error("Missing required parameter: document-id") response = await ok() @@ -45,7 +45,7 @@ class DocumentStreamExport: await lr.process( { "operation": "stream-document", - "user": user, + "workspace": workspace, "document-id": document_id, "chunk-size": chunk_size, }, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py index de0fe52d..91e47aaf 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py @@ -48,7 +48,6 @@ class EntityContextsImport: elt = EntityContexts( metadata=Metadata( id=data["metadata"]["id"], - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), entities=[ diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py index 7c7dc915..3e246335 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py @@ -48,7 +48,6 @@ class GraphEmbeddingsImport: elt = GraphEmbeddings( metadata=Metadata( id=data["metadata"]["id"], - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), entities=[ diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 592120b1..f3db3290 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -116,18 +116,20 @@ class DispatcherManager: # Format: {"config": {"request": "...", "response": "..."}, ...} self.queue_overrides = queue_overrides or {} + # Flows keyed by (workspace, flow_id) self.flows = {} + # Dispatchers keyed by (workspace, flow_id, kind) self.dispatchers = {} self.dispatcher_lock = asyncio.Lock() - async def start_flow(self, id, flow): - logger.info(f"Starting flow {id}") - self.flows[id] = flow + async def start_flow(self, workspace, id, flow): + logger.info(f"Starting flow {workspace}/{id}") + self.flows[(workspace, id)] = flow return - async def stop_flow(self, id, flow): - logger.info(f"Stopping flow {id}") - del self.flows[id] + async def stop_flow(self, workspace, id, flow): + logger.info(f"Stopping flow {workspace}/{id}") + del self.flows[(workspace, id)] return def dispatch_global_service(self): @@ -203,18 +205,20 @@ class DispatcherManager: async def process_flow_import(self, ws, running, params): + workspace = params.get("workspace", "default") flow = params.get("flow") kind = params.get("kind") - if flow not in self.flows: - raise RuntimeError("Invalid flow") + flow_key = (workspace, flow) + if flow_key not in self.flows: + raise RuntimeError(f"Invalid flow {workspace}/{flow}") if kind not in import_dispatchers: raise RuntimeError("Invalid kind") - key = (flow, kind) + key = (workspace, flow, kind) - intf_defs = self.flows[flow]["interfaces"] + intf_defs = self.flows[flow_key]["interfaces"] # FIXME: The -store bit, does it make sense? if kind == "entity-contexts": @@ -242,18 +246,20 @@ class DispatcherManager: async def process_flow_export(self, ws, running, params): + workspace = params.get("workspace", "default") flow = params.get("flow") kind = params.get("kind") - if flow not in self.flows: - raise RuntimeError("Invalid flow") + flow_key = (workspace, flow) + if flow_key not in self.flows: + raise RuntimeError(f"Invalid flow {workspace}/{flow}") if kind not in export_dispatchers: raise RuntimeError("Invalid kind") - key = (flow, kind) + key = (workspace, flow, kind) - intf_defs = self.flows[flow]["interfaces"] + intf_defs = self.flows[flow_key]["interfaces"] # FIXME: The -store bit, does it make sense? if kind == "entity-contexts": @@ -286,22 +292,36 @@ class DispatcherManager: async def process_flow_service(self, data, responder, params): + # Workspace can come from URL or from request body, defaulting + # to "default". Having it in the URL allows gateway routing to + # be workspace-aware without touching the body. + workspace = params.get("workspace") + if not workspace and isinstance(data, dict): + workspace = data.get("workspace") + if not workspace: + workspace = "default" + flow = params.get("flow") kind = params.get("kind") - return await self.invoke_flow_service(data, responder, flow, kind) + return await self.invoke_flow_service( + data, responder, workspace, flow, kind, + ) - async def invoke_flow_service(self, data, responder, flow, kind): + async def invoke_flow_service( + self, data, responder, workspace, flow, kind, + ): - if flow not in self.flows: - raise RuntimeError("Invalid flow") + flow_key = (workspace, flow) + if flow_key not in self.flows: + raise RuntimeError(f"Invalid flow {workspace}/{flow}") - key = (flow, kind) + key = (workspace, flow, kind) if key not in self.dispatchers: async with self.dispatcher_lock: if key not in self.dispatchers: - intf_defs = self.flows[flow]["interfaces"] + intf_defs = self.flows[flow_key]["interfaces"] if kind not in intf_defs: raise RuntimeError("This kind not supported by flow") @@ -314,8 +334,8 @@ class DispatcherManager: request_queue = qconfig["request"], response_queue = qconfig["response"], timeout = 120, - consumer = f"{self.prefix}-{flow}-{kind}-request", - subscriber = f"{self.prefix}-{flow}-{kind}-request", + consumer = f"{self.prefix}-{workspace}-{flow}-{kind}-request", + subscriber = f"{self.prefix}-{workspace}-{flow}-{kind}-request", ) elif kind in sender_dispatchers: dispatcher = sender_dispatchers[kind]( diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index fabd5c44..3d610dca 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -47,7 +47,9 @@ class Mux: raise RuntimeError("Bad message") await self.q.put(( - data["id"], data.get("flow"), + data["id"], + data.get("workspace", "default"), + data.get("flow"), data["service"], data["request"] )) @@ -87,8 +89,10 @@ class Mux: # worker[0] still running, move on break - async def start_request_task(self, ws, id, flow, svc, request, workers): - + async def start_request_task( + self, ws, id, workspace, flow, svc, request, workers, + ): + # Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS while len(workers) > MAX_OUTSTANDING_REQUESTS: @@ -106,19 +110,23 @@ class Mux: }) worker = asyncio.create_task( - self.request_task(id, request, responder, flow, svc) + self.request_task( + id, request, responder, workspace, flow, svc, + ) ) workers.append(worker) - async def request_task(self, id, request, responder, flow, svc): + async def request_task( + self, id, request, responder, workspace, flow, svc, + ): try: if flow: await self.dispatcher_manager.invoke_flow_service( - request, responder, flow, svc + request, responder, workspace, flow, svc, ) else: @@ -148,7 +156,7 @@ class Mux: # Get next request on queue item = await asyncio.wait_for(self.q.get(), 1) - id, flow, svc, request = item + id, workspace, flow, svc, request = item except TimeoutError: continue @@ -172,7 +180,7 @@ class Mux: try: await self.start_request_task( - self.ws, id, flow, svc, request, workers + self.ws, id, workspace, flow, svc, request, workers ) except Exception as e: diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py index ad634cab..8f92fa59 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py @@ -53,7 +53,6 @@ class RowsImport: elt = ExtractedObject( metadata=Metadata( id=data["metadata"]["id"], - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), schema_name=data["schema_name"], diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py index 7267e320..28b0ded5 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py @@ -38,7 +38,6 @@ def serialize_triples(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "triples": serialize_subgraph(message.triples), @@ -50,7 +49,6 @@ def serialize_graph_embeddings(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "entities": [ @@ -68,7 +66,6 @@ def serialize_entity_contexts(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "entities": [ @@ -86,7 +83,6 @@ def serialize_document_embeddings(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "chunks": [ @@ -120,8 +116,8 @@ def serialize_document_metadata(message): if message.metadata: ret["metadata"] = serialize_subgraph(message.metadata) - if message.user: - ret["user"] = message.user + if message.workspace: + ret["workspace"] = message.workspace if message.tags is not None: ret["tags"] = message.tags @@ -144,8 +140,8 @@ def serialize_processing_metadata(message): if message.flow: ret["flow"] = message.flow - if message.user: - ret["user"] = message.user + if message.workspace: + ret["workspace"] = message.workspace if message.collection: ret["collection"] = message.collection @@ -164,7 +160,7 @@ def to_document_metadata(x): title = x.get("title", None), comments = x.get("comments", None), metadata = to_subgraph(x["metadata"]), - user = x.get("user", None), + workspace = x.get("workspace", None), tags = x.get("tags", None), ) @@ -175,7 +171,7 @@ def to_processing_metadata(x): document_id = x.get("document-id", None), time = x.get("time", None), flow = x.get("flow", None), - user = x.get("user", None), + workspace = x.get("workspace", None), collection = x.get("collection", None), tags = x.get("tags", None), ) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py index 37f123fa..358faa8d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py @@ -49,7 +49,6 @@ class TriplesImport: metadata=Metadata( id=data["metadata"]["id"], root=data["metadata"].get("root", ""), - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), triples=to_subgraph(data["triples"]), diff --git a/trustgraph-flow/trustgraph/librarian/collection_manager.py b/trustgraph-flow/trustgraph/librarian/collection_manager.py index 34ce1de8..a450aded 100644 --- a/trustgraph-flow/trustgraph/librarian/collection_manager.py +++ b/trustgraph-flow/trustgraph/librarian/collection_manager.py @@ -20,7 +20,6 @@ logger = logging.getLogger(__name__) def metadata_to_dict(metadata: CollectionMetadata) -> dict: """Convert CollectionMetadata to dictionary for JSON serialization""" return { - 'user': metadata.user, 'collection': metadata.collection, 'name': metadata.name, 'description': metadata.description, @@ -92,38 +91,38 @@ class CollectionManager: self.pending_config_requests[response_id + "_response"] = response self.pending_config_requests[response_id].set() - async def ensure_collection_exists(self, user: str, collection: str): + async def ensure_collection_exists(self, workspace: str, collection: str): """ Ensure a collection exists, creating it if necessary Args: - user: User ID + workspace: Workspace ID collection: Collection ID """ try: # Check if collection exists via config service request = ConfigRequest( operation='get', - keys=[ConfigKey(type='collection', key=f'{user}:{collection}')] + workspace=workspace, + keys=[ConfigKey(type='collection', key=collection)] ) response = await self.send_config_request(request) # Validate response if not response.values or len(response.values) == 0: - raise Exception(f"Invalid response from config service when checking collection {user}/{collection}") + raise Exception(f"Invalid response from config service when checking collection {workspace}/{collection}") # Check if collection exists (value not None means it exists) if response.values[0].value is not None: - logger.debug(f"Collection {user}/{collection} already exists") + logger.debug(f"Collection {workspace}/{collection} already exists") return # Collection doesn't exist (value is None), proceed to create # Create new collection with default metadata - logger.info(f"Auto-creating collection {user}/{collection}") + logger.info(f"Auto-creating collection {workspace}/{collection}") metadata = CollectionMetadata( - user=user, collection=collection, name=collection, # Default name to collection ID description="", @@ -132,9 +131,10 @@ class CollectionManager: request = ConfigRequest( operation='put', + workspace=workspace, values=[ConfigValue( type='collection', - key=f'{user}:{collection}', + key=collection, value=json.dumps(metadata_to_dict(metadata)) )] ) @@ -144,7 +144,7 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config update failed: {response.error.message}") - logger.info(f"Collection {user}/{collection} auto-created in config service") + logger.info(f"Collection {workspace}/{collection} auto-created in config service") except Exception as e: logger.error(f"Error ensuring collection exists: {e}") @@ -161,9 +161,10 @@ class CollectionManager: CollectionManagementResponse with list of collections """ try: - # Get all collections from config service + # Get all collections in this workspace from config service config_request = ConfigRequest( operation='getvalues', + workspace=request.workspace, type='collection' ) @@ -172,15 +173,12 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config query failed: {response.error.message}") - # Parse collections and filter by user + # Every value in this workspace is a collection for this user collections = [] for config_value in response.values: - if ":" in config_value.key: - coll_user, coll_name = config_value.key.split(":", 1) - if coll_user == request.user: - metadata_dict = json.loads(config_value.value) - metadata = CollectionMetadata(**metadata_dict) - collections.append(metadata) + metadata_dict = json.loads(config_value.value) + metadata = CollectionMetadata(**metadata_dict) + collections.append(metadata) # Apply tag filtering if specified if request.tag_filter: @@ -221,7 +219,6 @@ class CollectionManager: tags = list(request.tags) if request.tags else [] metadata = CollectionMetadata( - user=request.user, collection=request.collection, name=name, description=description, @@ -231,9 +228,10 @@ class CollectionManager: # Send put request to config service config_request = ConfigRequest( operation='put', + workspace=request.workspace, values=[ConfigValue( type='collection', - key=f'{request.user}:{request.collection}', + key=request.collection, value=json.dumps(metadata_to_dict(metadata)) )] ) @@ -243,7 +241,7 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config update failed: {response.error.message}") - logger.info(f"Collection {request.user}/{request.collection} updated in config service") + logger.info(f"Collection {request.workspace}/{request.collection} updated in config service") # Config service will trigger config push automatically # Storage services will receive update and create/update collections @@ -269,12 +267,13 @@ class CollectionManager: CollectionManagementResponse indicating success or failure """ try: - logger.info(f"Deleting collection {request.user}/{request.collection}") + logger.info(f"Deleting collection {request.workspace}/{request.collection}") # Send delete request to config service config_request = ConfigRequest( operation='delete', - keys=[ConfigKey(type='collection', key=f'{request.user}:{request.collection}')] + workspace=request.workspace, + keys=[ConfigKey(type='collection', key=request.collection)] ) response = await self.send_config_request(config_request) @@ -282,7 +281,7 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config delete failed: {response.error.message}") - logger.info(f"Collection {request.user}/{request.collection} deleted from config service") + logger.info(f"Collection {request.workspace}/{request.collection} deleted from config service") # Config service will trigger config push automatically # Storage services will receive update and delete collections diff --git a/trustgraph-flow/trustgraph/librarian/librarian.py b/trustgraph-flow/trustgraph/librarian/librarian.py index 77232650..39a85809 100644 --- a/trustgraph-flow/trustgraph/librarian/librarian.py +++ b/trustgraph-flow/trustgraph/librarian/librarian.py @@ -48,7 +48,7 @@ class Librarian: raise RequestError("Document kind (MIME type) is required") if await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RuntimeError("Document already exists") @@ -78,7 +78,7 @@ class Librarian: logger.debug("Removing document...") if not await self.table_store.document_exists( - request.user, + request.workspace, request.document_id, ): raise RuntimeError("Document does not exist") @@ -89,17 +89,17 @@ class Librarian: logger.debug(f"Cascade deleting child document {child.id}") try: child_object_id = await self.table_store.get_document_object_id( - child.user, + child.workspace, child.id ) await self.blob_store.remove(child_object_id) - await self.table_store.remove_document(child.user, child.id) + await self.table_store.remove_document(child.workspace, child.id) except Exception as e: logger.warning(f"Failed to delete child document {child.id}: {e}") # Now remove the parent document object_id = await self.table_store.get_document_object_id( - request.user, + request.workspace, request.document_id ) @@ -108,7 +108,7 @@ class Librarian: # Remove doc table row await self.table_store.remove_document( - request.user, + request.workspace, request.document_id ) @@ -120,10 +120,10 @@ class Librarian: logger.debug("Updating document...") - # You can't update the document ID, user or kind. + # You can't update the document ID, workspace or kind. if not await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RuntimeError("Document does not exist") @@ -139,7 +139,7 @@ class Librarian: logger.debug("Getting document metadata...") doc = await self.table_store.get_document( - request.user, + request.workspace, request.document_id ) @@ -156,7 +156,7 @@ class Librarian: logger.debug("Getting document content...") object_id = await self.table_store.get_document_object_id( - request.user, + request.workspace, request.document_id ) @@ -180,18 +180,18 @@ class Librarian: raise RuntimeError("Collection parameter is required") if await self.table_store.processing_exists( - request.processing_metadata.user, + request.processing_metadata.workspace, request.processing_metadata.id ): raise RuntimeError("Processing already exists") doc = await self.table_store.get_document( - request.processing_metadata.user, + request.processing_metadata.workspace, request.processing_metadata.document_id ) object_id = await self.table_store.get_document_object_id( - request.processing_metadata.user, + request.processing_metadata.workspace, request.processing_metadata.document_id ) @@ -222,14 +222,14 @@ class Librarian: logger.debug("Removing processing metadata...") if not await self.table_store.processing_exists( - request.user, + request.workspace, request.processing_id, ): raise RuntimeError("Processing object does not exist") # Remove doc table row await self.table_store.remove_processing( - request.user, + request.workspace, request.processing_id ) @@ -239,7 +239,7 @@ class Librarian: async def list_documents(self, request): - docs = await self.table_store.list_documents(request.user) + docs = await self.table_store.list_documents(request.workspace) # Filter out child documents and answer documents by default include_children = getattr(request, 'include_children', False) @@ -256,7 +256,7 @@ class Librarian: async def list_processing(self, request): - procs = await self.table_store.list_processing(request.user) + procs = await self.table_store.list_processing(request.workspace) return LibrarianResponse( processing_metadatas = procs, @@ -276,7 +276,7 @@ class Librarian: raise RequestError("Document kind (MIME type) is required") if await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RequestError("Document already exists") @@ -312,14 +312,14 @@ class Librarian: "kind": request.document_metadata.kind, "title": request.document_metadata.title, "comments": request.document_metadata.comments, - "user": request.document_metadata.user, + "workspace": request.document_metadata.workspace, "tags": request.document_metadata.tags, }) # Store session in Cassandra await self.table_store.create_upload_session( upload_id=upload_id, - user=request.document_metadata.user, + workspace=request.document_metadata.workspace, document_id=request.document_metadata.id, document_metadata=doc_meta_json, s3_upload_id=s3_upload_id, @@ -352,7 +352,7 @@ class Librarian: raise RequestError("Upload session not found or expired") # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to upload to this session") # Validate chunk index @@ -419,7 +419,7 @@ class Librarian: raise RequestError("Upload session not found or expired") # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to complete this upload") # Verify all chunks received @@ -457,7 +457,7 @@ class Librarian: kind=doc_meta_dict["kind"], title=doc_meta_dict.get("title", ""), comments=doc_meta_dict.get("comments", ""), - user=doc_meta_dict["user"], + workspace=doc_meta_dict["workspace"], tags=doc_meta_dict.get("tags", []), metadata=[], # Triples not supported in chunked upload yet ) @@ -488,7 +488,7 @@ class Librarian: raise RequestError("Upload session not found or expired") # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to abort this upload") # Abort S3 multipart upload @@ -520,7 +520,7 @@ class Librarian: ) # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to view this upload") chunks_received = session["chunks_received"] @@ -548,11 +548,11 @@ class Librarian: async def list_uploads(self, request): """ - List all in-progress uploads for a user. + List all in-progress uploads for a workspace. """ - logger.debug(f"Listing uploads for user {request.user}") + logger.debug(f"Listing uploads for user {request.workspace}") - sessions = await self.table_store.list_upload_sessions(request.user) + sessions = await self.table_store.list_upload_sessions(request.workspace) upload_sessions = [ UploadSession( @@ -591,7 +591,7 @@ class Librarian: # Verify parent exists if not await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.parent_id ): raise RequestError( @@ -599,7 +599,7 @@ class Librarian: ) if await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RequestError("Document already exists") @@ -665,7 +665,7 @@ class Librarian: ) object_id = await self.table_store.get_document_object_id( - request.user, + request.workspace, request.document_id ) diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index ed005298..c24a5fe8 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -277,18 +277,22 @@ class Processor(AsyncProcessor): """Forward config responses to collection manager""" await self.collection_manager.on_config_response(message, consumer, flow) - async def on_librarian_config(self, config, version): + async def on_librarian_config(self, workspace, config, version): - logger.info(f"Configuration version: {version}") + logger.info( + f"Configuration version: {version} workspace: {workspace}" + ) if "flow" in config: - self.flows = { + self.flows[workspace] = { k: json.loads(v) for k, v in config["flow"].items() } + else: + self.flows[workspace] = {} - logger.debug(f"Flows: {self.flows}") + logger.debug(f"Flows for {workspace}: {self.flows[workspace]}") def __del__(self): @@ -345,7 +349,6 @@ class Processor(AsyncProcessor): metadata=Metadata( id=doc_uri, root=document.id, - user=processing.user, collection=processing.collection, ), triples=all_triples, @@ -363,10 +366,15 @@ class Processor(AsyncProcessor): logger.debug(f"Document: {document}, processing: {processing}, content length: {len(content)}") - if processing.flow not in self.flows: - raise RuntimeError("Invalid flow ID") + workspace = processing.workspace + ws_flows = self.flows.get(workspace, {}) + if processing.flow not in ws_flows: + raise RuntimeError( + f"Invalid flow ID {processing.flow} for workspace " + f"{workspace}" + ) - flow = self.flows[processing.flow] + flow = ws_flows[processing.flow] if document.kind == "text/plain": kind = "text-load" @@ -386,7 +394,6 @@ class Processor(AsyncProcessor): metadata = Metadata( id = document.id, root = document.id, - user = processing.user, collection = processing.collection ), document_id = document.id, @@ -398,7 +405,6 @@ class Processor(AsyncProcessor): metadata = Metadata( id = document.id, root = document.id, - user = processing.user, collection = processing.collection ), document_id = document.id, @@ -429,9 +435,9 @@ class Processor(AsyncProcessor): """ # Ensure collection exists when processing is added if hasattr(request, 'processing_metadata') and request.processing_metadata: - user = request.processing_metadata.user + workspace = request.processing_metadata.workspace collection = request.processing_metadata.collection - await self.collection_manager.ensure_collection_exists(user, collection) + await self.collection_manager.ensure_collection_exists(workspace, collection) # Call the original add_processing method return await self.librarian.add_processing(request) diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py index 46460b1f..a63b60ae 100644 --- a/trustgraph-flow/trustgraph/metering/counter.py +++ b/trustgraph-flow/trustgraph/metering/counter.py @@ -50,30 +50,37 @@ class Processor(FlowProcessor): ) ) + # Per-workspace price tables self.prices = {} self.config_key = "token-cost" - # Load token costs from the config service - async def on_cost_config(self, config, version): + async def on_cost_config(self, workspace, config, version): - logger.info(f"Loading metering configuration version {version}") + logger.info( + f"Loading metering configuration version {version} " + f"for workspace {workspace}" + ) if self.config_key not in config: - logger.warning(f"No key {self.config_key} in config") + logger.warning( + f"No key {self.config_key} in config for {workspace}" + ) + self.prices[workspace] = {} return - config = config[self.config_key] + prices = config[self.config_key] - self.prices = { + self.prices[workspace] = { k: json.loads(v) - for k, v in config.items() + for k, v in prices.items() } - def get_prices(self, modelname): + def get_prices(self, workspace, modelname): - if modelname in self.prices: - model = self.prices[modelname] + ws_prices = self.prices.get(workspace, {}) + if modelname in ws_prices: + model = ws_prices[modelname] return model["input_price"], model["output_price"] return None, None # Return None if model is not found @@ -81,6 +88,8 @@ class Processor(FlowProcessor): v = msg.value() + workspace = flow.workspace + modelname = v.model or "unknown" num_in = v.in_token or 0 num_out = v.out_token or 0 @@ -89,7 +98,9 @@ class Processor(FlowProcessor): __class__.token_metric.labels(model=modelname, direction="input").inc(num_in) __class__.token_metric.labels(model=modelname, direction="output").inc(num_out) - model_input_price, model_output_price = self.get_prices(modelname) + model_input_price, model_output_price = self.get_prices( + workspace, modelname + ) if model_input_price == None: cost_per_call = f"Model Not Found in Price list" diff --git a/trustgraph-flow/trustgraph/prompt/template/service.py b/trustgraph-flow/trustgraph/prompt/template/service.py index c599ce77..5da329d3 100755 --- a/trustgraph-flow/trustgraph/prompt/template/service.py +++ b/trustgraph-flow/trustgraph/prompt/template/service.py @@ -66,24 +66,37 @@ class Processor(FlowProcessor): self.register_config_handler(self.on_prompt_config, types=["prompt"]) - # Null configuration, should reload quickly - self.manager = PromptManager() + # Per-workspace prompt managers. Populated lazily as config + # arrives for each workspace. + self.managers = {} - async def on_prompt_config(self, config, version): + async def on_prompt_config(self, workspace, config, version): - logger.info(f"Loading prompt configuration version {version}") + logger.info( + f"Loading prompt configuration version {version} " + f"for workspace {workspace}" + ) if self.config_key not in config: - logger.warning(f"No key {self.config_key} in config") + logger.warning( + f"No key {self.config_key} in config for {workspace}" + ) return - config = config[self.config_key] + prompt_config = config[self.config_key] try: - self.manager.load_config(config) + manager = self.managers.get(workspace) + if manager is None: + manager = PromptManager() + self.managers[workspace] = manager - logger.info("Prompt configuration reloaded") + manager.load_config(prompt_config) + + logger.info( + f"Prompt configuration reloaded for {workspace}" + ) except Exception as e: @@ -103,6 +116,29 @@ class Processor(FlowProcessor): # Check if streaming is requested streaming = getattr(v, 'streaming', False) + # Look up the prompt manager for this workspace. If none is + # loaded yet, the request can't be handled. + workspace = flow.workspace + manager = self.managers.get(workspace) + if manager is None: + logger.error( + f"No prompt configuration loaded for workspace {workspace}" + ) + r = PromptResponse( + error=Error( + type="no-configuration", + message=( + f"No prompt configuration for workspace " + f"{workspace}" + ), + ), + text=None, + object=None, + end_of_stream=True, + ) + await flow("response").send(r, properties={"id": id}) + return + try: logger.debug(f"Prompt terms: {v.terms}") @@ -149,7 +185,7 @@ class Processor(FlowProcessor): return "" try: - await self.manager.invoke(kind, input, llm_streaming) + await manager.invoke(kind, input, llm_streaming) except Exception as e: logger.error(f"Prompt streaming exception: {e}", exc_info=True) raise e @@ -177,7 +213,7 @@ class Processor(FlowProcessor): return None try: - resp = await self.manager.invoke(kind, input, llm) + resp = await manager.invoke(kind, input, llm) except Exception as e: logger.error(f"Prompt invocation exception: {e}", exc_info=True) raise e diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index 98350961..0a1d8e0f 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -31,7 +31,7 @@ class Processor(DocumentEmbeddingsQueryService): self.vecstore = DocVectors(store_uri) - async def query_document_embeddings(self, msg): + async def query_document_embeddings(self, workspace, msg): try: @@ -45,7 +45,7 @@ class Processor(DocumentEmbeddingsQueryService): resp = self.vecstore.search( vec, - msg.user, + workspace, msg.collection, limit=msg.limit ) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index 406f979c..e1bc39fc 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -48,7 +48,7 @@ class Processor(DocumentEmbeddingsQueryService): } ) - async def query_document_embeddings(self, msg): + async def query_document_embeddings(self, workspace, msg): try: @@ -63,7 +63,7 @@ class Processor(DocumentEmbeddingsQueryService): dim = len(vec) # Use dimension suffix in index name - index_name = f"d-{msg.user}-{msg.collection}-{dim}" + index_name = f"d-{workspace}-{msg.collection}-{dim}" # Check if index exists - return empty if not if not self.pinecone.has_index(index_name): diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index f056b1c1..1d59c835 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -65,7 +65,7 @@ class Processor(DocumentEmbeddingsQueryService): """Check if collection exists (no implicit creation)""" return self.qdrant.collection_exists(collection) - async def query_document_embeddings(self, msg): + async def query_document_embeddings(self, workspace, msg): try: @@ -75,7 +75,7 @@ class Processor(DocumentEmbeddingsQueryService): # Use dimension suffix in collection name dim = len(vec) - collection = f"d_{msg.user}_{msg.collection}_{dim}" + collection = f"d_{workspace}_{msg.collection}_{dim}" # Check if collection exists - return empty if not if not self.collection_exists(collection): diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index 94eee387..1c5e8160 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -37,7 +37,7 @@ class Processor(GraphEmbeddingsQueryService): else: return Term(type=LITERAL, value=ent) - async def query_graph_embeddings(self, msg): + async def query_graph_embeddings(self, workspace, msg): try: @@ -51,7 +51,7 @@ class Processor(GraphEmbeddingsQueryService): resp = self.vecstore.search( vec, - msg.user, + workspace, msg.collection, limit=msg.limit * 2 ) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index ca443a6f..f612e3e8 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -55,7 +55,7 @@ class Processor(GraphEmbeddingsQueryService): else: return Term(type=LITERAL, value=ent) - async def query_graph_embeddings(self, msg): + async def query_graph_embeddings(self, workspace, msg): try: @@ -70,7 +70,7 @@ class Processor(GraphEmbeddingsQueryService): dim = len(vec) # Use dimension suffix in index name - index_name = f"t-{msg.user}-{msg.collection}-{dim}" + index_name = f"t-{workspace}-{msg.collection}-{dim}" # Check if index exists - return empty if not if not self.pinecone.has_index(index_name): diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index df93ad8b..b8fb1361 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -71,7 +71,7 @@ class Processor(GraphEmbeddingsQueryService): else: return Term(type=LITERAL, value=ent) - async def query_graph_embeddings(self, msg): + async def query_graph_embeddings(self, workspace, msg): try: @@ -81,7 +81,7 @@ class Processor(GraphEmbeddingsQueryService): # Use dimension suffix in collection name dim = len(vec) - collection = f"t_{msg.user}_{msg.collection}_{dim}" + collection = f"t_{workspace}_{msg.collection}_{dim}" # Check if collection exists - return empty if not if not self.collection_exists(collection): diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py index 7fc20303..da654247 100644 --- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -93,22 +93,22 @@ class Processor(FlowProcessor): return None - async def query_row_embeddings(self, request: RowEmbeddingsRequest): + async def query_row_embeddings(self, workspace, request: RowEmbeddingsRequest): """Execute row embeddings query""" vec = request.vector if not vec: return [] - # Find the collection for this user/collection/schema + # Find the collection for this workspace/collection/schema qdrant_collection = self.find_collection( - request.user, request.collection, request.schema_name + workspace, request.collection, request.schema_name ) if not qdrant_collection: logger.info( f"No Qdrant collection found for " - f"{request.user}/{request.collection}/{request.schema_name}" + f"{workspace}/{request.collection}/{request.schema_name}" ) return [] @@ -167,7 +167,7 @@ class Processor(FlowProcessor): ) # Execute query - matches = await self.query_row_embeddings(request) + matches = await self.query_row_embeddings(flow.workspace, request) response = RowEmbeddingsResponse( error=None, diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index 019d5610..20f31403 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -87,12 +87,12 @@ class Processor(FlowProcessor): # Register config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} - # GraphQL schema builder and generated schema - self.schema_builder = GraphQLSchemaBuilder() - self.graphql_schema = None + # Per-workspace GraphQL schema builders and compiled schemas + self.schema_builders: Dict[str, GraphQLSchemaBuilder] = {} + self.graphql_schemas: Dict[str, Any] = {} # Cassandra session self.cluster = None @@ -133,17 +133,27 @@ class Processor(FlowProcessor): safe_name = 'r_' + safe_name return safe_name.lower() - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} - self.schema_builder.clear() + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas + + builder = GraphQLSchemaBuilder() + self.schema_builders[workspace] = builder # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) + self.graphql_schemas[workspace] = None return # Get the schemas dictionary for our type @@ -177,17 +187,23 @@ class Processor(FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - self.schema_builder.add_schema(schema_name, row_schema) - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + builder.add_schema(schema_name, row_schema) + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) - # Regenerate GraphQL schema - self.graphql_schema = self.schema_builder.build(self.query_cassandra) + # Regenerate GraphQL schema for this workspace + self.graphql_schemas[workspace] = builder.build(self.query_cassandra) def get_index_names(self, schema: RowSchema) -> List[str]: """Get all index names for a schema.""" @@ -389,16 +405,21 @@ class Processor(FlowProcessor): async def execute_graphql_query( self, + workspace: str, query: str, variables: Dict[str, Any], operation_name: Optional[str], user: str, collection: str ) -> Dict[str, Any]: - """Execute a GraphQL query""" + """Execute a GraphQL query against the workspace's schema""" - if not self.graphql_schema: - raise RuntimeError("No GraphQL schema available - no schemas loaded") + graphql_schema = self.graphql_schemas.get(workspace) + if not graphql_schema: + raise RuntimeError( + f"No GraphQL schema available for workspace {workspace} " + f"- no schemas loaded" + ) # Create context for the query context = { @@ -408,7 +429,7 @@ class Processor(FlowProcessor): } # Execute the query - result = await self.graphql_schema.execute( + result = await graphql_schema.execute( query, variable_values=variables, operation_name=operation_name, @@ -454,6 +475,7 @@ class Processor(FlowProcessor): # Execute GraphQL query result = await self.execute_graphql_query( + workspace=flow.workspace, query=request.query, variables=dict(request.variables) if request.variables else {}, operation_name=request.operation_name, diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index 38488032..74dc7bbb 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -141,7 +141,7 @@ class Processor(FlowProcessor): solutions = await evaluate( parsed.algebra, triples_client, - user=request.user or "trustgraph", + user=flow.workspace, collection=request.collection or "default", limit=request.limit or 10000, ) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index 905aaaf2..f30f8259 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -197,15 +197,15 @@ class Processor(TriplesQueryService): ) self.table = user - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: # ensure_connection may construct a fresh # EntityCentricKnowledgeGraph which does sync schema # setup against Cassandra. Push it to a worker thread - # so the event loop doesn't block on first-use per user. - await asyncio.to_thread(self.ensure_connection, query.user) + # so the event loop doesn't block on first-use per workspace. + await asyncio.to_thread(self.ensure_connection, workspace) # Extract values from query s_val = get_term_value(query.s) @@ -359,13 +359,13 @@ class Processor(TriplesQueryService): logger.error(f"Exception querying triples: {e}", exc_info=True) raise e - async def query_triples_stream(self, query): + async def query_triples_stream(self, workspace, query): """ Streaming query - yields (batch, is_final) tuples. Uses Cassandra's paging to fetch results incrementally. """ try: - await asyncio.to_thread(self.ensure_connection, query.user) + await asyncio.to_thread(self.ensure_connection, workspace) batch_size = query.batch_size if query.batch_size > 0 else 20 limit = query.limit if query.limit > 0 else 10000 diff --git a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py index 14b24d52..9781aaaf 100755 --- a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py +++ b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py @@ -58,7 +58,7 @@ class Processor(TriplesQueryService): else: return Term(type=LITERAL, value=ent) - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: diff --git a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py index 37633f34..53f547fb 100755 --- a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py +++ b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py @@ -63,12 +63,11 @@ class Processor(TriplesQueryService): else: return Term(type=LITERAL, value=ent) - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: - # Extract user and collection, use defaults if not provided - user = query.user if query.user else "default" + user = workspace collection = query.collection if query.collection else "default" triples = [] diff --git a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py index 4cb1ab21..aa952120 100755 --- a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py +++ b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py @@ -63,12 +63,11 @@ class Processor(TriplesQueryService): else: return Term(type=LITERAL, value=ent) - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: - # Extract user and collection, use defaults if not provided - user = query.user if query.user else "default" + user = workspace collection = query.collection if query.collection else "default" triples = [] diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index dc7296ad..6dae21fd 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -96,19 +96,19 @@ class Processor(FlowProcessor): await super(Processor, self).start() await self.librarian.start() - async def fetch_chunk_content(self, chunk_id, user, timeout=120): + async def fetch_chunk_content(self, chunk_id, workspace, timeout=120): """Fetch chunk content from librarian. Chunks are small so single request-response is fine.""" return await self.librarian.fetch_document_text( - document_id=chunk_id, user=user, timeout=timeout, + document_id=chunk_id, workspace=workspace, timeout=timeout, ) - async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): """Save answer content to the librarian.""" doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "DocumentRAG Answer", document_type="answer", @@ -119,7 +119,7 @@ class Processor(FlowProcessor): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) await self.librarian.request(request, timeout=timeout) @@ -156,8 +156,7 @@ class Processor(FlowProcessor): await flow("explainability").send(Triples( metadata=Metadata( id=explain_id, - user=v.user, - collection=v.collection, # Store in user's collection + collection=v.collection, ), triples=triples, )) @@ -178,7 +177,7 @@ class Processor(FlowProcessor): async def save_answer(doc_id, answer_text): await self.save_answer_content( doc_id=doc_id, - user=v.user, + workspace=flow.workspace, content=answer_text, title=f"DocumentRAG Answer: {v.query[:50]}...", ) diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 15c30ba1..99c2a2ef 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -170,7 +170,7 @@ class Processor(FlowProcessor): future = self.pending_librarian_requests.pop(request_id) future.set_result(response) - async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): """ Save answer content to the librarian. @@ -188,7 +188,7 @@ class Processor(FlowProcessor): doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "GraphRAG Answer", document_type="answer", @@ -199,7 +199,7 @@ class Processor(FlowProcessor): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) # Create future for response @@ -247,8 +247,7 @@ class Processor(FlowProcessor): await flow("explainability").send(Triples( metadata=Metadata( id=explain_id, - user=v.user, - collection=v.collection, # Store in user's collection, not separate explainability collection + collection=v.collection, ), triples=triples, )) @@ -311,7 +310,7 @@ class Processor(FlowProcessor): async def save_answer(doc_id, answer_text): await self.save_answer_content( doc_id=doc_id, - user=v.user, + workspace=flow.workspace, content=answer_text, title=f"GraphRAG Answer: {v.query[:50]}...", ) diff --git a/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py index b567cc7b..091069ad 100644 --- a/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py +++ b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py @@ -66,32 +66,39 @@ class Processor(FlowProcessor): # Register config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} - + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} + logger.info("NLP Query service initialized") - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") - - # Clear existing schemas - self.schemas = {} - + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) + + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas + # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return - + # Get the schemas dictionary for our type schemas_config = config[self.config_key] - + # Process each schema in the schemas config for schema_name, schema_json in schemas_config.items(): try: # Parse the JSON schema definition schema_def = json.loads(schema_json) - + # Create Field objects fields = [] for field_def in schema_def.get("fields", []): @@ -106,29 +113,37 @@ class Processor(FlowProcessor): indexed=field_def.get("indexed", False) ) fields.append(field) - + # Create RowSchema row_schema = RowSchema( name=schema_def.get("name", schema_name), description=schema_def.get("description", ""), fields=fields ) - - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") - + + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) + except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) async def phase1_select_schemas(self, question: str, flow) -> List[str]: """Phase 1: Use prompt service to select relevant schemas for the question""" logger.info("Starting Phase 1: Schema selection") - + + ws_schemas = self.schemas.get(flow.workspace, {}) + # Prepare schema information for the prompt schema_info = [] - for name, schema in self.schemas.items(): + for name, schema in ws_schemas.items(): schema_desc = { "name": name, "description": schema.description, @@ -176,12 +191,14 @@ class Processor(FlowProcessor): async def phase2_generate_graphql(self, question: str, selected_schemas: List[str], flow) -> Dict[str, Any]: """Phase 2: Generate GraphQL query using selected schemas""" logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}") - + + ws_schemas = self.schemas.get(flow.workspace, {}) + # Get detailed schema information for selected schemas only selected_schema_info = [] for schema_name in selected_schemas: - if schema_name in self.schemas: - schema = self.schemas[schema_name] + if schema_name in ws_schemas: + schema = ws_schemas[schema_name] schema_desc = { "name": schema_name, "description": schema.description, diff --git a/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py b/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py index b878bf61..6dd79cbb 100644 --- a/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py +++ b/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py @@ -72,21 +72,28 @@ class Processor(FlowProcessor): # Register config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} logger.info("Structured Data Diagnosis service initialized") - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type @@ -120,13 +127,19 @@ class Processor(FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) async def on_message(self, msg, consumer, flow): """Handle incoming structured data diagnosis request""" @@ -216,15 +229,19 @@ class Processor(FlowProcessor): ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - # Get target schema - if request.schema_name not in self.schemas: + # Get target schema from this workspace's schemas + ws_schemas = self.schemas.get(flow.workspace, {}) + if request.schema_name not in ws_schemas: error = Error( type="SchemaNotFound", - message=f"Schema '{request.schema_name}' not found in configuration" + message=( + f"Schema '{request.schema_name}' not found " + f"in configuration for workspace {flow.workspace}" + ) ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - target_schema = self.schemas[request.schema_name] + target_schema = ws_schemas[request.schema_name] # Generate descriptor using prompt service descriptor = await self.generate_descriptor_with_prompt( @@ -260,26 +277,33 @@ class Processor(FlowProcessor): return StructuredDataDiagnosisResponse(error=error, operation=request.operation) # Step 2: Use provided schema name or auto-select first available + ws_schemas = self.schemas.get(flow.workspace, {}) schema_name = request.schema_name - if not schema_name and self.schemas: - schema_name = list(self.schemas.keys())[0] + if not schema_name and ws_schemas: + schema_name = list(ws_schemas.keys())[0] logger.info(f"Auto-selected schema: {schema_name}") if not schema_name: error = Error( type="NoSchemaAvailable", - message="No schema specified and no schemas available in configuration" + message=( + f"No schema specified and no schemas available " + f"in configuration for workspace {flow.workspace}" + ) ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - if schema_name not in self.schemas: + if schema_name not in ws_schemas: error = Error( type="SchemaNotFound", - message=f"Schema '{schema_name}' not found in configuration" + message=( + f"Schema '{schema_name}' not found in " + f"configuration for workspace {flow.workspace}" + ) ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - target_schema = self.schemas[schema_name] + target_schema = ws_schemas[schema_name] # Step 3: Generate descriptor descriptor = await self.generate_descriptor_with_prompt( @@ -316,8 +340,9 @@ class Processor(FlowProcessor): logger.info("Processing schema-selection operation") # Prepare all schemas for the prompt - match the original config format + ws_schemas = self.schemas.get(flow.workspace, {}) all_schemas = [] - for schema_name, row_schema in self.schemas.items(): + for schema_name, row_schema in ws_schemas.items(): schema_info = { "name": row_schema.name, "description": row_schema.description, diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index f5c12441..c8d08fb8 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -33,7 +33,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_document_embeddings(self, message): + async def store_document_embeddings(self, workspace, message): for emb in message.chunks: @@ -45,7 +45,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if vec: self.vecstore.insert( vec, chunk_id, - message.metadata.user, + workspace, message.metadata.collection ) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index 31a70f23..e628090a 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -88,12 +88,12 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): "Gave up waiting for index creation" ) - async def store_document_embeddings(self, message): + async def store_document_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." ) @@ -112,7 +112,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Create index name with dimension suffix for lazy creation dim = len(vec) index_name = ( - f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" + f"d-{workspace}-{message.metadata.collection}-{dim}" ) # Lazily create index if it doesn't exist (but only if authorized in config) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index e5e7e705..bebf944e 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -39,12 +39,12 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_document_embeddings(self, message): + async def store_document_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." ) @@ -63,7 +63,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( - f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" + f"d_{workspace}_{message.metadata.collection}_{dim}" ) # Lazily create collection if it doesn't exist (but only if authorized in config) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 9346c948..d5d424e9 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -47,7 +47,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_graph_embeddings(self, message): + async def store_graph_embeddings(self, workspace, message): for entity in message.entities: entity_value = get_term_value(entity.entity) @@ -57,7 +57,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if vec: self.vecstore.insert( vec, entity_value, - message.metadata.user, + workspace, message.metadata.collection, chunk_id=entity.chunk_id or "", ) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index 6a95a38d..d1fd498c 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -102,12 +102,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): "Gave up waiting for index creation" ) - async def store_graph_embeddings(self, message): + async def store_graph_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." ) @@ -126,7 +126,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Create index name with dimension suffix for lazy creation dim = len(vec) index_name = ( - f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" + f"t-{workspace}-{message.metadata.collection}-{dim}" ) # Lazily create index if it doesn't exist (but only if authorized in config) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 9a7672f8..8f34f4c6 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -54,12 +54,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_graph_embeddings(self, message): + async def store_graph_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." ) @@ -78,7 +78,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( - f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" + f"t_{workspace}_{message.metadata.collection}_{dim}" ) # Lazily create collection if it doesn't exist (but only if authorized in config) diff --git a/trustgraph-flow/trustgraph/storage/knowledge/store.py b/trustgraph-flow/trustgraph/storage/knowledge/store.py index 475604b6..57e1fe48 100644 --- a/trustgraph-flow/trustgraph/storage/knowledge/store.py +++ b/trustgraph-flow/trustgraph/storage/knowledge/store.py @@ -65,13 +65,13 @@ class Processor(FlowProcessor): v = msg.value() if v.triples: - await self.table_store.add_triples(v) + await self.table_store.add_triples(flow.workspace, v) async def on_graph_embeddings(self, msg, consumer, flow): v = msg.value() if v.entities: - await self.table_store.add_graph_embeddings(v) + await self.table_store.add_graph_embeddings(flow.workspace, v) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index a6ec4ff7..ad30e7bb 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -114,18 +114,19 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"{embeddings.schema_name} from {embeddings.metadata.id}" ) + workspace = flow.workspace + # Validate collection exists in config before processing if not self.collection_exists( - embeddings.metadata.user, embeddings.metadata.collection + workspace, embeddings.metadata.collection ): logger.warning( - f"Collection {embeddings.metadata.collection} for user " - f"{embeddings.metadata.user} does not exist in config. " + f"Collection {embeddings.metadata.collection} for workspace " + f"{workspace} does not exist in config. " f"Dropping message." ) return - user = embeddings.metadata.user collection = embeddings.metadata.collection schema_name = embeddings.schema_name @@ -145,7 +146,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Create/get collection name (lazily on first vector) if qdrant_collection is None: qdrant_collection = self.get_collection_name( - user, collection, schema_name, dimension + workspace, collection, schema_name, dimension ) self.ensure_collection(qdrant_collection, dimension) diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index d0eec2e1..b90ae9ee 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -119,19 +119,27 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) raise - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Track which schemas changed so we can clear partition cache - old_schema_names = set(self.schemas.keys()) + # Track which schemas changed in this workspace + old_schemas = self.schemas.get(workspace, {}) + old_schema_names = set(old_schemas.keys()) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type @@ -165,24 +173,32 @@ class Processor(CollectionConfigHandler, FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) - # Clear partition cache for schemas that changed - # This ensures next write will re-register partitions - new_schema_names = set(self.schemas.keys()) + # Clear partition cache for schemas that changed in this workspace + new_schema_names = set(ws_schemas.keys()) changed_schemas = old_schema_names.symmetric_difference(new_schema_names) if changed_schemas: self.registered_partitions = { (col, sch) for col, sch in self.registered_partitions if sch not in changed_schemas } - logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}") + logger.info( + f"Cleared partition cache for changed schemas " + f"in {workspace}: {changed_schemas}" + ) def sanitize_name(self, name: str) -> str: """Sanitize names for Cassandra compatibility""" @@ -286,7 +302,10 @@ class Processor(CollectionConfigHandler, FlowProcessor): return index_names - def register_partitions(self, keyspace: str, collection: str, schema_name: str): + def register_partitions( + self, keyspace: str, collection: str, schema_name: str, + workspace: str, + ): """ Register partition entries for a (collection, schema_name) pair. Called once on first row for each pair. @@ -295,9 +314,13 @@ class Processor(CollectionConfigHandler, FlowProcessor): if cache_key in self.registered_partitions: return - schema = self.schemas.get(schema_name) + ws_schemas = self.schemas.get(workspace, {}) + schema = ws_schemas.get(schema_name) if not schema: - logger.warning(f"Cannot register partitions - schema {schema_name} not found") + logger.warning( + f"Cannot register partitions - schema {schema_name} " + f"not found in workspace {workspace}" + ) return safe_keyspace = self.sanitize_name(keyspace) @@ -338,13 +361,14 @@ class Processor(CollectionConfigHandler, FlowProcessor): """Process incoming ExtractedObject and store in Cassandra""" obj = msg.value() + workspace = flow.workspace logger.info( f"Storing {len(obj.values)} rows for schema {obj.schema_name} " - f"from {obj.metadata.id}" + f"from {obj.metadata.id} (workspace {workspace})" ) # Validate collection exists before accepting writes - if not self.collection_exists(obj.metadata.user, obj.metadata.collection): + if not self.collection_exists(workspace, obj.metadata.collection): error_msg = ( f"Collection {obj.metadata.collection} does not exist. " f"Create it first via collection management API." @@ -352,13 +376,17 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error(error_msg) raise ValueError(error_msg) - # Get schema definition - schema = self.schemas.get(obj.schema_name) + # Get schema definition for this workspace + ws_schemas = self.schemas.get(workspace, {}) + schema = ws_schemas.get(obj.schema_name) if not schema: - logger.warning(f"No schema found for {obj.schema_name} - skipping") + logger.warning( + f"No schema found for {obj.schema_name} in " + f"workspace {workspace} - skipping" + ) return - keyspace = obj.metadata.user + keyspace = workspace collection = obj.metadata.collection schema_name = obj.schema_name source = getattr(obj.metadata, 'source', '') or '' @@ -370,7 +398,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Register partitions if first time seeing this (collection, schema_name) await asyncio.to_thread( - self.register_partitions, keyspace, collection, schema_name + self.register_partitions, + keyspace, collection, schema_name, workspace, ) safe_keyspace = self.sanitize_name(keyspace) diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 01d95c8b..8fff19d2 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -147,9 +147,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_triples(self, message): - - user = message.metadata.user + async def store_triples(self, workspace, message): # The cassandra-driver work below — connection, schema # setup, and per-triple inserts — is all synchronous. @@ -159,7 +157,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): def _do_store(): - if self.table is None or self.table != user: + if self.table is None or self.table != workspace: self.tg = None @@ -170,21 +168,21 @@ class Processor(CollectionConfigHandler, TriplesStoreService): if self.cassandra_username and self.cassandra_password: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=message.metadata.user, + keyspace=workspace, username=self.cassandra_username, password=self.cassandra_password, ) else: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=message.metadata.user, + keyspace=workspace, ) except Exception as e: logger.error(f"Exception: {e}", exc_info=True) time.sleep(1) raise e - self.table = user + self.table = workspace for t in message.triples: # Extract values from Term objects diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py index 86f9a6e3..bfb0988c 100755 --- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py @@ -162,13 +162,11 @@ class Processor(CollectionConfigHandler, TriplesStoreService): ) logger.info(f"Created collection metadata node for {user}/{collection}") - async def store_triples(self, message): - # Extract user and collection from metadata - user = message.metadata.user if message.metadata.user else "default" + async def store_triples(self, workspace, message): collection = message.metadata.collection if message.metadata.collection else "default" # Validate collection exists before accepting writes - if not self.collection_exists(user, collection): + if not self.collection_exists(workspace, collection): error_msg = ( f"Collection {collection} does not exist. " f"Create it first via collection management API." @@ -182,14 +180,14 @@ class Processor(CollectionConfigHandler, TriplesStoreService): p_val = get_term_value(t.p) o_val = get_term_value(t.o) - self.create_node(s_val, user, collection) + self.create_node(s_val, workspace, collection) if t.o.type == IRI: - self.create_node(o_val, user, collection) - self.relate_node(s_val, p_val, o_val, user, collection) + self.create_node(o_val, workspace, collection) + self.relate_node(s_val, p_val, o_val, workspace, collection) else: - self.create_literal(o_val, user, collection) - self.relate_literal(s_val, p_val, o_val, user, collection) + self.create_literal(o_val, workspace, collection) + self.relate_literal(s_val, p_val, o_val, workspace, collection) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py index 16a7d3ed..2ffb5e3d 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -258,14 +258,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService): src=s_val, dest=o_val, uri=p_val, user=user, collection=collection, ) - async def store_triples(self, message): + async def store_triples(self, workspace, message): - # Extract user and collection from metadata - user = message.metadata.user if message.metadata.user else "default" collection = message.metadata.collection if message.metadata.collection else "default" # Validate collection exists before accepting writes - if not self.collection_exists(user, collection): + if not self.collection_exists(workspace, collection): error_msg = ( f"Collection {collection} does not exist. " f"Create it first via collection management API." @@ -279,14 +277,14 @@ class Processor(CollectionConfigHandler, TriplesStoreService): p_val = get_term_value(t.p) o_val = get_term_value(t.o) - self.create_node(s_val, user, collection) + self.create_node(s_val, workspace, collection) if t.o.type == IRI: - self.create_node(o_val, user, collection) - self.relate_node(s_val, p_val, o_val, user, collection) + self.create_node(o_val, workspace, collection) + self.relate_node(s_val, p_val, o_val, workspace, collection) else: - self.create_literal(o_val, user, collection) - self.relate_literal(s_val, p_val, o_val, user, collection) + self.create_literal(o_val, workspace, collection) + self.relate_literal(s_val, p_val, o_val, workspace, collection) # Alternative implementation using transactions # with self.io.session(database=self.db) as session: diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index f7b2d947..dca571e2 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -209,14 +209,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - async def store_triples(self, message): + async def store_triples(self, workspace, message): - # Extract user and collection from metadata - user = message.metadata.user if message.metadata.user else "default" collection = message.metadata.collection if message.metadata.collection else "default" # Validate collection exists before accepting writes - if not self.collection_exists(user, collection): + if not self.collection_exists(workspace, collection): error_msg = ( f"Collection {collection} does not exist. " f"Create it first via collection management API." @@ -230,14 +228,14 @@ class Processor(CollectionConfigHandler, TriplesStoreService): p_val = get_term_value(t.p) o_val = get_term_value(t.o) - self.create_node(s_val, user, collection) + self.create_node(s_val, workspace, collection) if t.o.type == IRI: - self.create_node(o_val, user, collection) - self.relate_node(s_val, p_val, o_val, user, collection) + self.create_node(o_val, workspace, collection) + self.relate_node(s_val, p_val, o_val, workspace, collection) else: - self.create_literal(o_val, user, collection) - self.relate_literal(s_val, p_val, o_val, user, collection) + self.create_literal(o_val, workspace, collection) + self.relate_literal(s_val, p_val, o_val, workspace, collection) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/tables/config.py b/trustgraph-flow/trustgraph/tables/config.py index d9a8711b..8fd00427 100644 --- a/trustgraph-flow/trustgraph/tables/config.py +++ b/trustgraph-flow/trustgraph/tables/config.py @@ -72,10 +72,11 @@ class ConfigTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS config ( + workspace text, class text, key text, value text, - PRIMARY KEY (class, key) + PRIMARY KEY ((workspace, class), key) ); """); @@ -124,52 +125,63 @@ class ConfigTableStore: def prepare_statements(self): self.put_config_stmt = self.cassandra.prepare(""" - INSERT INTO config ( class, key, value ) - VALUES (?, ?, ?) - """) - - self.get_classes_stmt = self.cassandra.prepare(""" - SELECT DISTINCT class FROM config; + INSERT INTO config ( workspace, class, key, value ) + VALUES (?, ?, ?, ?) """) self.get_keys_stmt = self.cassandra.prepare(""" - SELECT key FROM config WHERE class = ?; + SELECT key FROM config + WHERE workspace = ? AND class = ?; """) self.get_value_stmt = self.cassandra.prepare(""" - SELECT value FROM config WHERE class = ? AND key = ?; + SELECT value FROM config + WHERE workspace = ? AND class = ? AND key = ?; """) self.delete_key_stmt = self.cassandra.prepare(""" DELETE FROM config - WHERE class = ? AND key = ?; + WHERE workspace = ? AND class = ? AND key = ?; """) self.get_all_stmt = self.cassandra.prepare(""" - SELECT class AS cls, key, value FROM config; + SELECT workspace, class AS cls, key, value FROM config; + """) + + self.get_all_for_workspace_stmt = self.cassandra.prepare(""" + SELECT class AS cls, key, value FROM config + WHERE workspace = ? + ALLOW FILTERING; """) self.get_values_stmt = self.cassandra.prepare(""" - SELECT key, value FROM config WHERE class = ?; + SELECT key, value FROM config + WHERE workspace = ? AND class = ?; """) - async def put_config(self, cls, key, value): + self.get_values_all_ws_stmt = self.cassandra.prepare(""" + SELECT workspace, key, value FROM config + WHERE class = ? + ALLOW FILTERING; + """) + + async def put_config(self, workspace, cls, key, value): try: await async_execute( self.cassandra, self.put_config_stmt, - (cls, key, value), + (workspace, cls, key, value), ) except Exception: logger.error("Exception occurred", exc_info=True) raise - async def get_value(self, cls, key): + async def get_value(self, workspace, cls, key): try: rows = await async_execute( self.cassandra, self.get_value_stmt, - (cls, key), + (workspace, cls, key), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -179,12 +191,12 @@ class ConfigTableStore: return row[0] return None - async def get_values(self, cls): + async def get_values(self, workspace, cls): try: rows = await async_execute( self.cassandra, self.get_values_stmt, - (cls,), + (workspace, cls), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -192,18 +204,20 @@ class ConfigTableStore: return [[row[0], row[1]] for row in rows] - async def get_classes(self): + async def get_values_all_ws(self, cls): + """Return (workspace, key, value) tuples for all workspaces + with entries of the given class.""" try: rows = await async_execute( self.cassandra, - self.get_classes_stmt, - (), + self.get_values_all_ws_stmt, + (cls,), ) except Exception: logger.error("Exception occurred", exc_info=True) raise - return [row[0] for row in rows] + return [(row[0], row[1], row[2]) for row in rows] async def get_all(self): try: @@ -216,14 +230,27 @@ class ConfigTableStore: logger.error("Exception occurred", exc_info=True) raise + return [(row[0], row[1], row[2], row[3]) for row in rows] + + async def get_all_for_workspace(self, workspace): + try: + rows = await async_execute( + self.cassandra, + self.get_all_for_workspace_stmt, + (workspace,), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + return [(row[0], row[1], row[2]) for row in rows] - async def get_keys(self, cls): + async def get_keys(self, workspace, cls): try: rows = await async_execute( self.cassandra, self.get_keys_stmt, - (cls,), + (workspace, cls), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -231,12 +258,12 @@ class ConfigTableStore: return [row[0] for row in rows] - async def delete_key(self, cls, key): + async def delete_key(self, workspace, cls, key): try: await async_execute( self.cassandra, self.delete_key_stmt, - (cls, key), + (workspace, cls, key), ) except Exception: logger.error("Exception occurred", exc_info=True) diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py index b06f4862..4d729956 100644 --- a/trustgraph-flow/trustgraph/tables/knowledge.py +++ b/trustgraph-flow/trustgraph/tables/knowledge.py @@ -88,7 +88,7 @@ class KnowledgeTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS triples ( - user text, + workspace text, document_id text, id uuid, time timestamp, @@ -98,7 +98,7 @@ class KnowledgeTableStore: triples list>, - PRIMARY KEY ((user, document_id), id) + PRIMARY KEY ((workspace, document_id), id) ); """); @@ -106,7 +106,7 @@ class KnowledgeTableStore: self.cassandra.execute(""" create table if not exists graph_embeddings ( - user text, + workspace text, document_id text, id uuid, time timestamp, @@ -119,20 +119,20 @@ class KnowledgeTableStore: list > >, - PRIMARY KEY ((user, document_id), id) + PRIMARY KEY ((workspace, document_id), id) ); """); self.cassandra.execute(""" - CREATE INDEX IF NOT EXISTS graph_embeddings_user ON - graph_embeddings ( user ); + CREATE INDEX IF NOT EXISTS graph_embeddings_workspace ON + graph_embeddings ( workspace ); """); logger.debug("document_embeddings table...") self.cassandra.execute(""" create table if not exists document_embeddings ( - user text, + workspace text, document_id text, id uuid, time timestamp, @@ -145,13 +145,13 @@ class KnowledgeTableStore: list > >, - PRIMARY KEY ((user, document_id), id) + PRIMARY KEY ((workspace, document_id), id) ); """); self.cassandra.execute(""" - CREATE INDEX IF NOT EXISTS document_embeddings_user ON - document_embeddings ( user ); + CREATE INDEX IF NOT EXISTS document_embeddings_workspace ON + document_embeddings ( workspace ); """); logger.info("Cassandra schema OK.") @@ -161,7 +161,7 @@ class KnowledgeTableStore: self.insert_triples_stmt = self.cassandra.prepare(""" INSERT INTO triples ( - id, user, document_id, + id, workspace, document_id, time, metadata, triples ) VALUES (?, ?, ?, ?, ?, ?) @@ -170,7 +170,7 @@ class KnowledgeTableStore: self.insert_graph_embeddings_stmt = self.cassandra.prepare(""" INSERT INTO graph_embeddings ( - id, user, document_id, time, metadata, entity_embeddings + id, workspace, document_id, time, metadata, entity_embeddings ) VALUES (?, ?, ?, ?, ?, ?) """) @@ -178,45 +178,45 @@ class KnowledgeTableStore: self.insert_document_embeddings_stmt = self.cassandra.prepare(""" INSERT INTO document_embeddings ( - id, user, document_id, time, metadata, chunks + id, workspace, document_id, time, metadata, chunks ) VALUES (?, ?, ?, ?, ?, ?) """) self.list_cores_stmt = self.cassandra.prepare(""" - SELECT DISTINCT user, document_id FROM graph_embeddings - WHERE user = ? + SELECT DISTINCT workspace, document_id FROM graph_embeddings + WHERE workspace = ? """) self.get_triples_stmt = self.cassandra.prepare(""" SELECT id, time, metadata, triples FROM triples - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) self.get_graph_embeddings_stmt = self.cassandra.prepare(""" SELECT id, time, metadata, entity_embeddings FROM graph_embeddings - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) self.get_document_embeddings_stmt = self.cassandra.prepare(""" SELECT id, time, metadata, chunks FROM document_embeddings - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) self.delete_triples_stmt = self.cassandra.prepare(""" DELETE FROM triples - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) self.delete_graph_embeddings_stmt = self.cassandra.prepare(""" DELETE FROM graph_embeddings - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) - async def add_triples(self, m): + async def add_triples(self, workspace, m): when = int(time.time() * 1000) @@ -232,7 +232,7 @@ class KnowledgeTableStore: self.cassandra, self.insert_triples_stmt, ( - uuid.uuid4(), m.metadata.user, + uuid.uuid4(), workspace, m.metadata.root or m.metadata.id, when, [], triples, ), @@ -241,7 +241,7 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - async def add_graph_embeddings(self, m): + async def add_graph_embeddings(self, workspace, m): when = int(time.time() * 1000) @@ -258,7 +258,7 @@ class KnowledgeTableStore: self.cassandra, self.insert_graph_embeddings_stmt, ( - uuid.uuid4(), m.metadata.user, + uuid.uuid4(), workspace, m.metadata.root or m.metadata.id, when, [], entities, ), @@ -267,7 +267,7 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - async def add_document_embeddings(self, m): + async def add_document_embeddings(self, workspace, m): when = int(time.time() * 1000) @@ -284,7 +284,7 @@ class KnowledgeTableStore: self.cassandra, self.insert_document_embeddings_stmt, ( - uuid.uuid4(), m.metadata.user, + uuid.uuid4(), workspace, m.metadata.root or m.metadata.id, when, [], chunks, ), @@ -293,7 +293,7 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - async def list_kg_cores(self, user): + async def list_kg_cores(self, workspace): logger.debug("List kg cores...") @@ -301,7 +301,7 @@ class KnowledgeTableStore: rows = await async_execute( self.cassandra, self.list_cores_stmt, - (user,), + (workspace,), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -313,7 +313,7 @@ class KnowledgeTableStore: return lst - async def delete_kg_core(self, user, document_id): + async def delete_kg_core(self, workspace, document_id): logger.debug("Delete kg cores...") @@ -321,7 +321,7 @@ class KnowledgeTableStore: await async_execute( self.cassandra, self.delete_triples_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -331,13 +331,13 @@ class KnowledgeTableStore: await async_execute( self.cassandra, self.delete_graph_embeddings_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) raise - async def get_triples(self, user, document_id, receiver): + async def get_triples(self, workspace, document_id, receiver): logger.debug("Get triples...") @@ -345,7 +345,7 @@ class KnowledgeTableStore: rows = await async_execute( self.cassandra, self.get_triples_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -369,7 +369,6 @@ class KnowledgeTableStore: Triples( metadata = Metadata( id = document_id, - user = user, collection = "default", # FIXME: What to put here? ), triples = triples @@ -378,7 +377,7 @@ class KnowledgeTableStore: logger.debug("Done") - async def get_graph_embeddings(self, user, document_id, receiver): + async def get_graph_embeddings(self, workspace, document_id, receiver): logger.debug("Get GE...") @@ -386,7 +385,7 @@ class KnowledgeTableStore: rows = await async_execute( self.cassandra, self.get_graph_embeddings_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -409,12 +408,11 @@ class KnowledgeTableStore: GraphEmbeddings( metadata = Metadata( id = document_id, - user = user, collection = "default", # FIXME: What to put here? ), entities = entities ) - ) + ) logger.debug("Done") diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py index c85ae72a..86706079 100644 --- a/trustgraph-flow/trustgraph/tables/library.py +++ b/trustgraph-flow/trustgraph/tables/library.py @@ -64,7 +64,7 @@ class LibraryTableStore: self.cluster = Cluster(cassandra_host) self.cassandra = self.cluster.connect() - + logger.info("Connected.") self.ensure_cassandra_schema() @@ -76,13 +76,13 @@ class LibraryTableStore: logger.debug("Ensure Cassandra schema...") logger.debug("Keyspace...") - + # FIXME: Replication factor should be configurable self.cassandra.execute(f""" create keyspace if not exists {self.keyspace} - with replication = {{ - 'class' : 'SimpleStrategy', - 'replication_factor' : 1 + with replication = {{ + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 }}; """); @@ -93,7 +93,7 @@ class LibraryTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS document ( id text, - user text, + workspace text, time timestamp, kind text, title text, @@ -103,7 +103,9 @@ class LibraryTableStore: >>, tags list, object_id uuid, - PRIMARY KEY (user, id) + parent_id text, + document_type text, + PRIMARY KEY (workspace, id) ); """); @@ -114,27 +116,6 @@ class LibraryTableStore: ON document (object_id) """); - # Add parent_id and document_type columns for child document support - logger.debug("document table parent_id column...") - - try: - self.cassandra.execute(""" - ALTER TABLE document ADD parent_id text - """); - except Exception as e: - # Column may already exist - if "already exists" not in str(e).lower() and "Invalid column name" not in str(e): - logger.debug(f"parent_id column may already exist: {e}") - - try: - self.cassandra.execute(""" - ALTER TABLE document ADD document_type text - """); - except Exception as e: - # Column may already exist - if "already exists" not in str(e).lower() and "Invalid column name" not in str(e): - logger.debug(f"document_type column may already exist: {e}") - logger.debug("document parent index...") self.cassandra.execute(""" @@ -150,10 +131,10 @@ class LibraryTableStore: document_id text, time timestamp, flow text, - user text, + workspace text, collection text, tags list, - PRIMARY KEY (user, id) + PRIMARY KEY (workspace, id) ); """); @@ -162,7 +143,7 @@ class LibraryTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS upload_session ( upload_id text PRIMARY KEY, - user text, + workspace text, document_id text, document_metadata text, s3_upload_id text, @@ -176,11 +157,11 @@ class LibraryTableStore: ) WITH default_time_to_live = 86400; """); - logger.debug("upload_session user index...") + logger.debug("upload_session workspace index...") self.cassandra.execute(""" - CREATE INDEX IF NOT EXISTS upload_session_user - ON upload_session (user) + CREATE INDEX IF NOT EXISTS upload_session_workspace + ON upload_session (workspace) """); logger.info("Cassandra schema OK.") @@ -190,7 +171,7 @@ class LibraryTableStore: self.insert_document_stmt = self.cassandra.prepare(""" INSERT INTO document ( - id, user, time, + id, workspace, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type @@ -202,25 +183,25 @@ class LibraryTableStore: UPDATE document SET time = ?, title = ?, comments = ?, metadata = ?, tags = ? - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.get_document_stmt = self.cassandra.prepare(""" SELECT time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.delete_document_stmt = self.cassandra.prepare(""" DELETE FROM document - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.test_document_exists_stmt = self.cassandra.prepare(""" SELECT id FROM document - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? LIMIT 1 """) @@ -229,7 +210,7 @@ class LibraryTableStore: id, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document - WHERE user = ? + WHERE workspace = ? """) self.list_document_by_tag_stmt = self.cassandra.prepare(""" @@ -237,7 +218,7 @@ class LibraryTableStore: id, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document - WHERE user = ? AND tags CONTAINS ? + WHERE workspace = ? AND tags CONTAINS ? ALLOW FILTERING """) @@ -245,7 +226,7 @@ class LibraryTableStore: INSERT INTO processing ( id, document_id, time, - flow, user, collection, + flow, workspace, collection, tags ) VALUES (?, ?, ?, ?, ?, ?, ?) @@ -253,13 +234,13 @@ class LibraryTableStore: self.delete_processing_stmt = self.cassandra.prepare(""" DELETE FROM processing - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.test_processing_exists_stmt = self.cassandra.prepare(""" SELECT id FROM processing - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? LIMIT 1 """) @@ -267,14 +248,14 @@ class LibraryTableStore: SELECT id, document_id, time, flow, collection, tags FROM processing - WHERE user = ? + WHERE workspace = ? """) # Upload session prepared statements self.insert_upload_session_stmt = self.cassandra.prepare(""" INSERT INTO upload_session ( - upload_id, user, document_id, document_metadata, + upload_id, workspace, document_id, document_metadata, s3_upload_id, object_id, total_size, chunk_size, total_chunks, chunks_received, created_at, updated_at ) @@ -283,7 +264,7 @@ class LibraryTableStore: self.get_upload_session_stmt = self.cassandra.prepare(""" SELECT - upload_id, user, document_id, document_metadata, + upload_id, workspace, document_id, document_metadata, s3_upload_id, object_id, total_size, chunk_size, total_chunks, chunks_received, created_at, updated_at FROM upload_session @@ -308,25 +289,25 @@ class LibraryTableStore: total_size, chunk_size, total_chunks, chunks_received, created_at, updated_at FROM upload_session - WHERE user = ? + WHERE workspace = ? """) # Child document queries self.list_children_stmt = self.cassandra.prepare(""" SELECT - id, user, time, kind, title, comments, metadata, tags, + id, workspace, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document WHERE parent_id = ? ALLOW FILTERING """) - async def document_exists(self, user, id): + async def document_exists(self, workspace, id): rows = await async_execute( self.cassandra, self.test_document_exists_stmt, - (user, id), + (workspace, id), ) return bool(rows) @@ -351,7 +332,7 @@ class LibraryTableStore: self.cassandra, self.insert_document_stmt, ( - document.id, document.user, int(document.time * 1000), + document.id, document.workspace, int(document.time * 1000), document.kind, document.title, document.comments, metadata, document.tags, object_id, parent_id, document_type @@ -381,7 +362,7 @@ class LibraryTableStore: ( int(document.time * 1000), document.title, document.comments, metadata, document.tags, - document.user, document.id + document.workspace, document.id ), ) except Exception: @@ -390,7 +371,7 @@ class LibraryTableStore: logger.debug("Update complete") - async def remove_document(self, user, document_id): + async def remove_document(self, workspace, document_id): logger.info(f"Removing document {document_id}") @@ -398,7 +379,7 @@ class LibraryTableStore: await async_execute( self.cassandra, self.delete_document_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -406,7 +387,7 @@ class LibraryTableStore: logger.debug("Delete complete") - async def list_documents(self, user): + async def list_documents(self, workspace): logger.debug("List documents...") @@ -414,7 +395,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.list_document_stmt, - (user,), + (workspace,), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -423,7 +404,7 @@ class LibraryTableStore: lst = [ DocumentMetadata( id = row[0], - user = user, + workspace = workspace, time = int(time.mktime(row[1].timetuple())), kind = row[2], title = row[3], @@ -465,7 +446,7 @@ class LibraryTableStore: lst = [ DocumentMetadata( id = row[0], - user = row[1], + workspace = row[1], time = int(time.mktime(row[2].timetuple())), kind = row[3], title = row[4], @@ -489,7 +470,7 @@ class LibraryTableStore: return lst - async def get_document(self, user, id): + async def get_document(self, workspace, id): logger.debug("Get document") @@ -497,7 +478,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.get_document_stmt, - (user, id), + (workspace, id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -506,7 +487,7 @@ class LibraryTableStore: for row in rows: doc = DocumentMetadata( id = id, - user = user, + workspace = workspace, time = int(time.mktime(row[0].timetuple())), kind = row[1], title = row[2], @@ -529,7 +510,7 @@ class LibraryTableStore: raise RuntimeError("No such document row?") - async def get_document_object_id(self, user, id): + async def get_document_object_id(self, workspace, id): logger.debug("Get document obj ID") @@ -537,7 +518,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.get_document_stmt, - (user, id), + (workspace, id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -549,12 +530,12 @@ class LibraryTableStore: raise RuntimeError("No such document row?") - async def processing_exists(self, user, id): + async def processing_exists(self, workspace, id): rows = await async_execute( self.cassandra, self.test_processing_exists_stmt, - (user, id), + (workspace, id), ) return bool(rows) @@ -570,7 +551,7 @@ class LibraryTableStore: ( processing.id, processing.document_id, int(processing.time * 1000), processing.flow, - processing.user, processing.collection, + processing.workspace, processing.collection, processing.tags ), ) @@ -580,7 +561,7 @@ class LibraryTableStore: logger.debug("Add complete") - async def remove_processing(self, user, processing_id): + async def remove_processing(self, workspace, processing_id): logger.info(f"Removing processing {processing_id}") @@ -588,7 +569,7 @@ class LibraryTableStore: await async_execute( self.cassandra, self.delete_processing_stmt, - (user, processing_id), + (workspace, processing_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -596,7 +577,7 @@ class LibraryTableStore: logger.debug("Delete complete") - async def list_processing(self, user): + async def list_processing(self, workspace): logger.debug("List processing objects") @@ -604,7 +585,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.list_processing_stmt, - (user,), + (workspace,), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -616,7 +597,7 @@ class LibraryTableStore: document_id = row[1], time = int(time.mktime(row[2].timetuple())), flow = row[3], - user = user, + workspace = workspace, collection = row[4], tags = row[5] if row[5] else [], ) @@ -632,7 +613,7 @@ class LibraryTableStore: async def create_upload_session( self, upload_id, - user, + workspace, document_id, document_metadata, s3_upload_id, @@ -652,7 +633,7 @@ class LibraryTableStore: self.cassandra, self.insert_upload_session_stmt, ( - upload_id, user, document_id, document_metadata, + upload_id, workspace, document_id, document_metadata, s3_upload_id, object_id, total_size, chunk_size, total_chunks, {}, now, now ), @@ -681,7 +662,7 @@ class LibraryTableStore: for row in rows: session = { "upload_id": row[0], - "user": row[1], + "workspace": row[1], "document_id": row[2], "document_metadata": row[3], "s3_upload_id": row[4], @@ -738,16 +719,16 @@ class LibraryTableStore: logger.debug("Upload session deleted") - async def list_upload_sessions(self, user): - """List all upload sessions for a user.""" + async def list_upload_sessions(self, workspace): + """List all upload sessions for a workspace.""" - logger.debug(f"List upload sessions for {user}") + logger.debug(f"List upload sessions for {workspace}") try: rows = await async_execute( self.cassandra, self.list_upload_sessions_stmt, - (user,), + (workspace,), ) except Exception: logger.error("Exception occurred", exc_info=True) diff --git a/trustgraph-ocr/pyproject.toml b/trustgraph-ocr/pyproject.toml index cd1d20a1..1718258f 100644 --- a/trustgraph-ocr/pyproject.toml +++ b/trustgraph-ocr/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py index 4844b104..9d955d17 100755 --- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py +++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py @@ -91,7 +91,7 @@ class Processor(FlowProcessor): if v.document_id: doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": logger.error( @@ -106,7 +106,7 @@ class Processor(FlowProcessor): logger.info(f"Fetching document {v.document_id} from librarian...") content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if isinstance(content, str): content = content.encode('utf-8') @@ -141,7 +141,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=page_content, document_type="page", title=f"Page {page_num}", @@ -163,7 +163,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -175,7 +174,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), document_id=page_doc_id, diff --git a/trustgraph-unstructured/pyproject.toml b/trustgraph-unstructured/pyproject.toml index d8879329..7169fc8b 100644 --- a/trustgraph-unstructured/pyproject.toml +++ b/trustgraph-unstructured/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "pulsar-client", "prometheus-client", "python-magic", diff --git a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py index 6b7d0246..b3723655 100644 --- a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py +++ b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py @@ -275,7 +275,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=doc_id, parent_id=parent_doc_id, - user=metadata.user, + workspace=flow.workspace, content=page_content, document_type="page" if is_page else "section", title=label, @@ -303,7 +303,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=entity_uri, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -314,7 +313,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=entity_uri, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), document_id=doc_id, @@ -356,7 +354,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=img_uri, parent_id=parent_doc_id, - user=metadata.user, + workspace=flow.workspace, content=img_content, document_type="image", title=f"Image from page {page_number}" if page_number else "Image", @@ -379,7 +377,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=img_uri, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -404,13 +401,13 @@ class Processor(FlowProcessor): doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) mime_type = doc_meta.kind if doc_meta else None content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if isinstance(content, str): diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml index 45958ef3..f43f154d 100644 --- a/trustgraph-vertexai/pyproject.toml +++ b/trustgraph-vertexai/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "pulsar-client", "google-genai", "google-api-core",