release/v2.4 -> master (#844)

This commit is contained in:
cybermaggedon 2026-04-22 15:19:57 +01:00 committed by GitHub
parent a24df8e990
commit 89cabee1b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
386 changed files with 7202 additions and 5741 deletions


@@ -22,7 +22,7 @@ jobs:
       uses: actions/checkout@v3
     - name: Setup packages
-      run: make update-package-versions VERSION=2.3.999
+      run: make update-package-versions VERSION=2.4.999
     - name: Setup environment
       run: python3 -m venv env

.gitignore

@@ -15,4 +15,5 @@ trustgraph-parquet/trustgraph/parquet_version.py
 trustgraph-vertexai/trustgraph/vertexai_version.py
 trustgraph-unstructured/trustgraph/unstructured_version.py
 trustgraph-mcp/trustgraph/mcp_version.py
+trustgraph/trustgraph/trustgraph_version.py
 vertexai/


@@ -57,7 +57,7 @@ container-bedrock container-vertexai \
 	container-hf container-ocr \
 	container-unstructured container-mcp
-some-containers: container-base container-flow
+some-containers: container-base container-flow container-unstructured
 push:
 	${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION}


@@ -0,0 +1,309 @@
---
layout: default
title: "Data Ownership and Information Separation"
parent: "Tech Specs"
---
# Data Ownership and Information Separation
## Purpose
This document defines the logical ownership model for data in
TrustGraph: what the artefacts are, who owns them, and how they relate
to each other.
The IAM spec ([iam.md](iam.md)) describes authentication and
authorisation mechanics. This spec addresses the prior question: what
are the boundaries around data, and who owns what?
## Concepts
### Workspace
A workspace is the primary isolation boundary. It represents an
organisation, team, or independent operating unit. All data belongs to
exactly one workspace. Cross-workspace access is never permitted through
the API.
A workspace owns:
- Source documents
- Flows (processing pipeline definitions)
- Knowledge cores (stored extraction output)
- Collections (organisational units for extracted knowledge)
### Collection
A collection is an organisational unit within a workspace. It groups
extracted knowledge produced from source documents. A workspace can
have multiple collections, allowing:
- Processing the same documents with different parameters or models.
- Maintaining separate knowledge bases for different purposes.
- Deleting extracted knowledge without deleting source documents.
Collections do not own source documents. A source document exists at the
workspace level and can be processed into multiple collections.
### Source document
A source document (PDF, text file, etc.) is raw input uploaded to the
system. Documents belong to the workspace, not to a specific collection.
This is intentional. A document is an asset that exists independently
of how it is processed. The same PDF might be processed into multiple
collections with different chunking parameters or extraction models.
Tying a document to a single collection would force re-upload for each
collection.
### Flow
A flow defines a processing pipeline: which models to use, what
parameters to apply (chunk size, temperature, etc.), and how processing
services are connected. Flows belong to a workspace.
The processing services themselves (document-decoder, chunker,
embeddings, LLM completion, etc.) are shared infrastructure — they serve
all workspaces. Each flow has its own queues, keeping data from
different workspaces and flows separate as it moves through the
pipeline.
Different workspaces can define different flows. Workspace A might use
GPT-5.2 with a chunk size of 2000, while workspace B uses Claude with a
chunk size of 1000.
### Prompts
Prompts are templates that control how the LLM behaves during knowledge
extraction and query answering. They belong to a workspace, allowing
different workspaces to have different extraction strategies, response
styles, or domain-specific instructions.
### Ontology
An ontology defines the concepts, entities, and relationships that the
extraction pipeline looks for in source documents. Ontologies belong to
a workspace. A medical workspace might define ontologies around diseases,
symptoms, and treatments, while a legal workspace defines ontologies
around statutes, precedents, and obligations.
### Schemas
Schemas define structured data types for extraction. They specify what
fields to extract, their types, and how they relate. Schemas belong to
a workspace, as different workspaces extract different structured
information from their documents.
### Tools, tool services, and MCP tools
Tools define capabilities available to agents: what actions they can
take, what external services they can call. Tool services configure how
tools connect to backend services. MCP tools configure connections to
remote MCP servers, including authentication tokens. All belong to a
workspace.
### Agent patterns and agent task types
Agent patterns define agent behaviour strategies (how an agent reasons,
what steps it follows). Agent task types define the kinds of tasks
agents can perform. Both belong to a workspace, as different workspaces
may have different agent configurations.
### Token costs
Token cost definitions specify pricing for LLM token usage per model.
These belong to a workspace since different workspaces may use different
models or have different billing arrangements.
### Flow blueprints
Flow blueprints are templates for creating flows. They define the
default pipeline structure and parameters. Blueprints belong to a
workspace, allowing workspaces to define custom processing templates.
### Parameter types
Parameter types define the kinds of parameters that flows accept (e.g.
"llm-model", "temperature"), including their defaults and validation
rules. They belong to a workspace since workspaces that define custom
flows need to define the parameter types those flows use.
### Interface descriptions
Interface descriptions define the connection points of a flow — what
queues and topics it uses. They belong to a workspace since they
describe workspace-owned flows.
### Knowledge core
A knowledge core is a stored snapshot of extracted knowledge (triples
and graph embeddings). Knowledge cores belong to a workspace and can be
loaded into any collection within that workspace.
Knowledge cores serve as a portable extraction output. You process
documents through a flow, the pipeline produces triples and embeddings,
and the results can be stored as a knowledge core. That core can later
be loaded into a different collection or reloaded after a collection is
cleared.
### Extracted knowledge
Extracted knowledge is the live, queryable content within a collection:
triples in the knowledge graph, graph embeddings, and document
embeddings. It is the product of processing source documents through a
flow into a specific collection.
Extracted knowledge is scoped to a workspace and a collection. It
cannot exist without both.
### Processing record
A processing record tracks which source document was processed, through
which flow, into which collection. It links the source document
(workspace-scoped) to the extracted knowledge (workspace + collection
scoped).
## Ownership summary
| Artefact | Owned by | Shared across collections? |
|----------|----------|---------------------------|
| Workspaces | Global (platform) | N/A |
| User accounts | Global (platform) | N/A |
| API keys | Global (platform) | N/A |
| Source documents | Workspace | Yes |
| Flows | Workspace | N/A |
| Flow blueprints | Workspace | N/A |
| Prompts | Workspace | N/A |
| Ontologies | Workspace | N/A |
| Schemas | Workspace | N/A |
| Tools | Workspace | N/A |
| Tool services | Workspace | N/A |
| MCP tools | Workspace | N/A |
| Agent patterns | Workspace | N/A |
| Agent task types | Workspace | N/A |
| Token costs | Workspace | N/A |
| Parameter types | Workspace | N/A |
| Interface descriptions | Workspace | N/A |
| Knowledge cores | Workspace | Yes — can be loaded into any collection |
| Collections | Workspace | N/A |
| Extracted knowledge | Workspace + collection | No |
| Processing records | Workspace + collection | No |
## Scoping summary
### Global (system-level)
A small number of artefacts exist outside any workspace:
- **Workspace registry** — the list of workspaces itself
- **User accounts** — users reference a workspace but are not owned by
one
- **API keys** — belong to users, not workspaces
These are managed by the IAM layer and exist at the platform level.
### Workspace-owned
All other configuration and data is workspace-owned:
- Flow definitions and parameters
- Flow blueprints
- Prompts
- Ontologies
- Schemas
- Tools, tool services, and MCP tools
- Agent patterns and agent task types
- Token costs
- Parameter types
- Interface descriptions
- Collection definitions
- Knowledge cores
- Source documents
- Collections and their extracted knowledge
## Relationship between artefacts
```
Platform (global)
|
+-- Workspaces
|
+-- User accounts (each assigned to a workspace)
    |
    +-- API keys (belong to users)

Workspace
|
+-- Source documents (uploaded, unprocessed)
|
+-- Flows (pipeline definitions: models, parameters, queues)
|
+-- Flow blueprints (templates for creating flows)
|
+-- Prompts (LLM instruction templates)
|
+-- Ontologies (entity and relationship definitions)
|
+-- Schemas (structured data type definitions)
|
+-- Tools, tool services, MCP tools (agent capabilities)
|
+-- Agent patterns and agent task types (agent behaviour)
|
+-- Token costs (LLM pricing per model)
|
+-- Parameter types (flow parameter definitions)
|
+-- Interface descriptions (flow connection points)
|
+-- Knowledge cores (stored extraction snapshots)
|
+-- Collections
    |
    +-- Extracted knowledge (triples, embeddings)
    |
    +-- Processing records (links documents to collections)
```
A typical workflow:
1. A source document is uploaded to the workspace.
2. A flow defines how to process it (which models, what parameters).
3. The document is processed through the flow into a collection.
4. Processing records track what was processed.
5. Extracted knowledge (triples, embeddings) is queryable within the
collection.
6. Optionally, the extracted knowledge is stored as a knowledge core
for later reuse.
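The ownership rules behind this workflow can be sketched in a few lines of Python. All class and method names below are hypothetical illustrations of the model in this spec, not TrustGraph APIs; the point is that a document belongs to the workspace while extracted knowledge and processing records are scoped to a collection.

```python
from dataclasses import dataclass, field

# Hypothetical in-memory model of the ownership rules in this spec.
# These class names are illustrative, not TrustGraph code.

@dataclass
class ProcessingRecord:
    document: str      # workspace-scoped source document
    flow: str          # workspace-owned pipeline definition
    collection: str    # target collection for the extracted knowledge

@dataclass
class Workspace:
    id: str
    documents: set = field(default_factory=set)
    # Extracted knowledge is keyed by collection: it cannot exist
    # without both a workspace and a collection.
    collections: dict = field(default_factory=dict)
    records: list = field(default_factory=list)

    def process(self, document, flow, collection):
        # Documents belong to the workspace, not a collection, so the
        # same document can be processed into multiple collections.
        if document not in self.documents:
            raise ValueError("document not in this workspace")
        self.collections.setdefault(collection, []).append(
            f"triples from {document} via {flow}")
        self.records.append(ProcessingRecord(document, flow, collection))

ws = Workspace("acme")
ws.documents.add("report.pdf")
ws.process("report.pdf", "gpt-flow", "research")
ws.process("report.pdf", "claude-flow", "compliance")  # same doc, new collection
```

Each `process` call leaves a processing record, so the provenance of every piece of extracted knowledge remains queryable.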
## Implementation notes
The current codebase uses a `user` field in message metadata and storage
partition keys to identify the workspace. The `collection` field
identifies the collection within that workspace. The IAM spec describes
how the gateway maps authenticated credentials to a workspace identity
and sets these fields.
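As a concrete illustration of that mapping, the core `Metadata` dataclass (shown in the IAM spec) carries the workspace in its `user` field and the collection in `collection`. The partition-key helper below is a hypothetical sketch of how storage backends scope data, not code from the repository.

```python
from dataclasses import dataclass

@dataclass
class Metadata:
    id: str = ""
    root: str = ""
    user: str = ""        # holds the workspace identifier
    collection: str = ""  # collection within that workspace

def partition_key(md: Metadata) -> tuple:
    # Hypothetical sketch: backends partition and filter rows on
    # (user, collection), i.e. (workspace, collection).
    return (md.user, md.collection)

md = Metadata(id="doc-1", user="ws-acme", collection="research")
```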
For details on how each storage backend implements this scoping, see:
- [Entity-Centric Graph](entity-centric-graph.md) — Cassandra KG schema
- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md)
- [Collection Management](collection-management.md)
### Known inconsistencies in current implementation
- **Pipeline intermediate tables** do not include collection in their
partition keys. Re-processing the same document into a different
collection may overwrite intermediate state.
- **Processing metadata** stores collection in the row payload but not
in the partition key, making collection-based queries inefficient.
- **Upload sessions** are keyed by upload ID, not workspace. The
gateway should validate workspace ownership before allowing
operations on upload sessions.
## References
- [Identity and Access Management](iam.md)
- [Collection Management](collection-management.md)
- [Entity-Centric Graph](entity-centric-graph.md)
- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md)
- [Multi-Tenant Support](multi-tenant-support.md)


@@ -20,8 +20,8 @@ Defines shared service processors that are instantiated once per flow blueprint.
 ```json
 "class": {
     "service-name:{class}": {
-        "request": "queue-pattern:{class}",
-        "response": "queue-pattern:{class}",
+        "request": "queue-pattern:{workspace}:{class}",
+        "response": "queue-pattern:{workspace}:{class}",
         "settings": {
             "setting-name": "fixed-value",
             "parameterized-setting": "{parameter-name}"
@@ -31,11 +31,11 @@ Defines shared service processors that are instantiated once per flow blueprint.
 ```
 **Characteristics:**
-- Shared across all flow instances of the same class
+- Shared across all flow instances of the same class within a workspace
 - Typically expensive or stateless services (LLMs, embedding models)
-- Use `{class}` template variable for queue naming
+- Use `{workspace}` and `{class}` template variables for queue naming
 - Settings can be fixed values or parameterized with `{parameter-name}` syntax
-- Examples: `embeddings:{class}`, `text-completion:{class}`, `graph-rag:{class}`
+- Examples: `embeddings:{workspace}:{class}`, `text-completion:{workspace}:{class}`
 ### 2. Flow Section
 Defines flow-specific processors that are instantiated for each individual flow instance. Each flow gets its own isolated set of these processors.
@@ -43,8 +43,8 @@ Defines flow-specific processors that are instantiated for each individual flow
 ```json
 "flow": {
     "processor-name:{id}": {
-        "input": "queue-pattern:{id}",
-        "output": "queue-pattern:{id}",
+        "input": "queue-pattern:{workspace}:{id}",
+        "output": "queue-pattern:{workspace}:{id}",
         "settings": {
             "setting-name": "fixed-value",
             "parameterized-setting": "{parameter-name}"
@@ -56,9 +56,9 @@ Defines flow-specific processors that are instantiated for each individual flow
 **Characteristics:**
 - Unique instance per flow
 - Handle flow-specific data and state
-- Use `{id}` template variable for queue naming
+- Use `{workspace}` and `{id}` template variables for queue naming
 - Settings can be fixed values or parameterized with `{parameter-name}` syntax
-- Examples: `chunker:{id}`, `pdf-decoder:{id}`, `kg-extract-relationships:{id}`
+- Examples: `chunker:{workspace}:{id}`, `pdf-decoder:{workspace}:{id}`
 ### 3. Interfaces Section
 Defines the entry points and interaction contracts for the flow. These form the API surface for external systems and internal component communication.
@@ -68,8 +68,8 @@ Interfaces can take two forms:
 **Fire-and-Forget Pattern** (single queue):
 ```json
 "interfaces": {
-    "document-load": "persistent://tg/flow/document-load:{id}",
-    "triples-store": "persistent://tg/flow/triples-store:{id}"
+    "document-load": "persistent://tg/flow/{workspace}:document-load:{id}",
+    "triples-store": "persistent://tg/flow/{workspace}:triples-store:{id}"
 }
 ```
@@ -77,8 +77,8 @@ Interfaces can take two forms:
 ```json
 "interfaces": {
     "embeddings": {
-        "request": "non-persistent://tg/request/embeddings:{class}",
-        "response": "non-persistent://tg/response/embeddings:{class}"
+        "request": "non-persistent://tg/request/{workspace}:embeddings:{class}",
+        "response": "non-persistent://tg/response/{workspace}:embeddings:{class}"
     }
 }
 ```
@@ -117,6 +117,16 @@ Additional information about the flow blueprint:
 ### System Variables
+#### {workspace}
+- Replaced with the workspace identifier
+- Isolates queue names between workspaces so that two workspaces
+  starting the same flow do not share queues
+- Must be included in all queue name patterns to ensure workspace
+  isolation
+- Example: `ws-acme`, `ws-globex`
+- All blueprint templates must include `{workspace}` in queue name
+  patterns
 #### {id}
 - Replaced with the unique flow instance identifier
 - Creates isolated resources for each flow

docs/tech-specs/iam.md

@@ -0,0 +1,858 @@
---
layout: default
title: "Identity and Access Management"
parent: "Tech Specs"
---
# Identity and Access Management
## Problem Statement
TrustGraph has no meaningful identity or access management. The system
relies on a single shared gateway token for authentication and an
honour-system `user` query parameter for data isolation. This creates
several problems:
- **No user identity.** There are no user accounts, no login, and no way
to know who is making a request. The `user` field in message metadata
is a caller-supplied string with no validation — any client can claim
to be any user.
- **No access control.** A valid gateway token grants unrestricted access
to every endpoint, every user's data, every collection, and every
administrative operation. There is no way to limit what an
authenticated caller can do.
- **No credential isolation.** All callers share one static token. There
is no per-user credential, no token expiration, and no rotation
mechanism. Revoking access means changing the shared token, which
affects all callers.
- **Data isolation is unenforced.** Storage backends (Cassandra, Neo4j,
Qdrant) filter queries by `user` and `collection`, but the gateway
does not prevent a caller from specifying another user's identity.
Cross-user data access is trivial.
- **No audit trail.** There is no logging of who accessed what. Without
user identity, audit logging is impossible.
These gaps make the system unsuitable for multi-user deployments,
multi-tenant SaaS, or any environment where access needs to be
controlled or audited.
## Current State
### Authentication
The API gateway supports a single shared token configured via the
`GATEWAY_SECRET` environment variable or `--api-token` CLI argument. If
unset, authentication is disabled entirely. When enabled, every HTTP
endpoint requires an `Authorization: Bearer <token>` header. WebSocket
connections pass the token as a query parameter.
Implementation: `trustgraph-flow/trustgraph/gateway/auth.py`
```python
class Authenticator:

    def __init__(self, token=None, allow_all=False):
        self.token = token
        self.allow_all = allow_all

    def permitted(self, token, roles):
        if self.allow_all:
            return True
        if self.token != token:
            return False
        return True
```
The `roles` parameter is accepted but never evaluated. All authenticated
requests have identical privileges.
MCP tool configurations support an optional per-tool `auth-token` for
service-to-service authentication with remote MCP servers. These are
static, system-wide tokens — not per-user credentials. See
[mcp-tool-bearer-token.md](mcp-tool-bearer-token.md) for details.
### User identity
The `user` field is passed explicitly by the caller as a query parameter
(e.g. `?user=trustgraph`) or set by CLI tools. It flows through the
system in the core `Metadata` dataclass:
```python
@dataclass
class Metadata:
    id: str = ""
    root: str = ""
    user: str = ""
    collection: str = ""
```
There is no user registration, login, user database, or session
management.
### Data isolation
The `user` + `collection` pair is used at the storage layer to partition
data:
- **Cassandra**: queries filter by `user` and `collection` columns
- **Neo4j**: queries filter by `user` and `collection` properties
- **Qdrant**: vector search filters by `user` and `collection` metadata
| Layer | Isolation mechanism | Enforced by |
|-------|-------------------|-------------|
| Gateway | Single shared token | `Authenticator` class |
| Message metadata | `user` + `collection` fields | Caller (honour system) |
| Cassandra | Column filters on `user`, `collection` | Query layer |
| Neo4j | Property filters on `user`, `collection` | Query layer |
| Qdrant | Metadata filters on `user`, `collection` | Query layer |
| Pub/sub topics | Per-flow topic namespacing | Flow service |
The storage-layer isolation depends on all queries correctly filtering by
`user` and `collection`. There is no gateway-level enforcement preventing
a caller from querying another user's data by passing a different `user`
parameter.
### Configuration and secrets
| Setting | Source | Default | Purpose |
|---------|--------|---------|---------|
| `GATEWAY_SECRET` | Env var | Empty (auth disabled) | Gateway bearer token |
| `--api-token` | CLI arg | None | Gateway bearer token (overrides env) |
| `PULSAR_API_KEY` | Env var | None | Pub/sub broker auth |
| MCP `auth-token` | Config service | None | Per-tool MCP server auth |
No secrets are encrypted at rest. The gateway token and MCP tokens are
stored and transmitted in plaintext (aside from any transport-layer
encryption such as TLS).
### Capabilities that do not exist
- Per-user authentication (JWT, OAuth, SAML, API keys per user)
- User accounts or user management
- Role-based access control (RBAC)
- Attribute-based access control (ABAC)
- Per-user or per-workspace API keys
- Token expiration or rotation
- Session management
- Per-user rate limiting
- Audit logging of user actions
- Permission checks preventing cross-user data access
- Multi-workspace credential isolation
### Key files
| File | Purpose |
|------|---------|
| `trustgraph-flow/trustgraph/gateway/auth.py` | Authenticator class |
| `trustgraph-flow/trustgraph/gateway/service.py` | Gateway init, token config |
| `trustgraph-flow/trustgraph/gateway/endpoint/*.py` | Per-endpoint auth checks |
| `trustgraph-base/trustgraph/schema/core/metadata.py` | `Metadata` dataclass with `user` field |
## Technical Design
### Design principles
- **Auth at the edge.** The gateway is the single enforcement point.
Internal services trust the gateway and do not re-authenticate.
This avoids distributing credential validation across dozens of
microservices.
- **Identity from credentials, not from callers.** The gateway derives
user identity from authentication credentials. Callers can no longer
self-declare their identity via query parameters.
- **Workspace isolation by default.** Every authenticated user belongs to
a workspace. All data operations are scoped to that workspace.
Cross-workspace access is not possible through the API.
- **Extensible API contract.** The API accepts an optional workspace
parameter on every request. This allows the same protocol to support
single-workspace deployments today and multi-workspace extensions in
the future without breaking changes.
- **Simple roles, not fine-grained permissions.** A small number of
predefined roles controls what operations a user can perform. This is
sufficient for the current API surface and avoids the complexity of
per-resource permission management.
### Authentication
The gateway supports two credential types. Both are carried as a Bearer
token in the `Authorization` header for HTTP requests. The gateway
distinguishes them by format.
For WebSocket connections, credentials are not passed in the URL or
headers. Instead, the client authenticates after connecting by sending
an auth message as the first frame:
```
Client: opens WebSocket to /api/v1/socket
Server: accepts connection (unauthenticated state)
Client: sends {"type": "auth", "token": "tg_abc123..."}
Server: validates token
        success → {"type": "auth-ok", "workspace": "acme"}
        failure → {"type": "auth-failed", "error": "invalid token"}
```
The server rejects all non-auth messages until authentication succeeds.
The socket remains open on auth failure, allowing the client to retry
with a different token without reconnecting. The client can also send
a new auth message at any time to re-authenticate — for example, to
refresh an expiring JWT or to switch workspace. The
resolved identity (user, workspace, roles) is updated on each
successful auth.
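The first-frame protocol above amounts to a small per-connection state machine. The sketch below is a minimal, transport-free illustration; the token table and message shapes stand in for real validation against the IAM service.

```python
import json

# Stand-in for real token resolution via the IAM service (assumption
# for illustration only).
VALID = {"tg_abc123": {"user": "alice", "workspace": "acme",
                       "roles": ["reader"]}}

class Connection:
    """One WebSocket connection: unauthenticated until an auth frame
    succeeds; re-auth at any time replaces the resolved identity."""

    def __init__(self):
        self.identity = None

    def on_message(self, frame: str) -> str:
        msg = json.loads(frame)
        if msg.get("type") == "auth":
            ident = VALID.get(msg.get("token"))
            if ident:
                self.identity = ident  # re-auth updates identity
                return json.dumps({"type": "auth-ok",
                                   "workspace": ident["workspace"]})
            # Socket stays open so the client can retry.
            return json.dumps({"type": "auth-failed",
                               "error": "invalid token"})
        if self.identity is None:
            # Reject all non-auth messages until authenticated.
            return json.dumps({"type": "error",
                               "error": "authenticate first"})
        return json.dumps({"type": "ok"})
```

A query sent before the auth frame is rejected, while the same query after a successful auth frame is served.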
#### API keys
For programmatic access: CLI tools, scripts, and integrations.
- Opaque tokens (e.g. `tg_a1b2c3d4e5f6...`). Not JWTs — short,
simple, easy to paste into CLI tools and headers.
- Each user has one or more API keys.
- Keys are stored hashed (SHA-256 with salt) in the IAM service. The
plaintext key is returned once at creation time and cannot be
retrieved afterwards.
- Keys can be revoked individually without affecting other users.
- Keys optionally have an expiry date. Expired keys are rejected.
On each request, the gateway resolves an API key by:
1. Hashing the token.
2. Checking a local cache (hash → user/workspace/roles).
3. On cache miss, calling the IAM service to resolve.
4. Caching the result with a short TTL (e.g. 60 seconds).
Revoked keys stop working when the cache entry expires. No push
invalidation is needed.
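The resolution path can be sketched as follows. The `iam_resolve` stub stands in for the pub/sub call to the IAM service; only successful lookups are cached, so a revoked key stops working once its cache entry expires.

```python
import hashlib
import time

# Stand-in for the IAM service's key table (illustrative only):
# hash of the plaintext key -> resolved identity.
IAM_DB = {hashlib.sha256(b"tg_a1b2c3").hexdigest():
          {"user": "alice", "workspace": "acme", "roles": ["writer"]}}

def iam_resolve(key_hash):
    # Sketch of the IAM service call; None means unknown or revoked.
    return IAM_DB.get(key_hash)

_cache = {}       # key_hash -> (identity, expiry)
CACHE_TTL = 60.0  # short TTL; revocation takes effect when it lapses

def resolve_api_key(token: str):
    key_hash = hashlib.sha256(token.encode()).hexdigest()  # 1. hash
    hit = _cache.get(key_hash)                             # 2. cache
    if hit and hit[1] > time.monotonic():
        return hit[0]
    identity = iam_resolve(key_hash)                       # 3. IAM call
    if identity is not None:                               # 4. cache w/ TTL
        _cache[key_hash] = (identity, time.monotonic() + CACHE_TTL)
    return identity
```

Note that the production scheme described above also salts the stored hashes; the unsalted hash here keeps the sketch short.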
#### JWTs (login sessions)
For interactive access via the UI or WebSocket connections.
- A user logs in with username and password. The gateway forwards the
request to the IAM service, which validates the credentials and
returns a signed JWT.
- The JWT carries the user ID, workspace, and roles as claims.
- The gateway validates JWTs locally using the IAM service's public
signing key — no service call needed on subsequent requests.
- Token expiry is enforced by standard JWT validation at the time the
request (or WebSocket connection) is made.
- For long-lived WebSocket connections, the JWT is validated at connect
time only. The connection remains authenticated for its lifetime.
The IAM service manages the signing key. The gateway fetches the public
key at startup (or on first JWT encounter) and caches it.
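The local validation step has the following shape. The spec calls for asymmetric signatures checked against the IAM service's public key; this standard-library-only sketch substitutes HMAC (HS256) so the structure of the check — signature first, then expiry, then claims — stays visible without external dependencies.

```python
import base64, hashlib, hmac, json, time

def b64url(data: bytes) -> bytes:
    return base64.urlsafe_b64encode(data).rstrip(b"=")

def make_jwt(claims: dict, secret: bytes) -> str:
    # HS256 stand-in for the IAM service's token signing.
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    payload = b64url(json.dumps(claims).encode())
    sig = b64url(hmac.new(secret, header + b"." + payload,
                          hashlib.sha256).digest())
    return b".".join([header, payload, sig]).decode()

def validate_jwt(token: str, secret: bytes):
    header, payload, sig = token.encode().split(b".")
    expected = b64url(hmac.new(secret, header + b"." + payload,
                               hashlib.sha256).digest())
    if not hmac.compare_digest(sig, expected):
        return None  # bad signature
    claims = json.loads(base64.urlsafe_b64decode(
        payload + b"=" * (-len(payload) % 4)))
    if claims.get("exp", float("inf")) < time.time():
        return None  # expired
    return claims    # carries user ID, workspace, roles

secret = b"demo-secret"  # illustrative; the real key lives in the IAM service
token = make_jwt({"sub": "alice", "workspace": "acme",
                  "roles": ["reader"], "exp": time.time() + 3600}, secret)
claims = validate_jwt(token, secret)
```

In the real design only `validate_jwt` runs in the gateway; token creation stays with the IAM service.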
#### Login endpoint
```
POST /api/v1/auth/login
{
    "username": "alice",
    "password": "..."
}
→ {
    "token": "eyJ...",
    "expires": "2026-04-20T19:00:00Z"
}
```
The gateway forwards this to the IAM service, which validates
credentials and returns a signed JWT. The gateway returns the JWT to
the caller.
#### IAM service delegation
The gateway stays thin. Its authentication logic is:
1. Extract the Bearer token from the `Authorization` header (or, for a
   WebSocket connection, from the first-frame auth message).
2. If the token has JWT format (dotted structure), validate the
signature locally and extract claims.
3. Otherwise, treat as an API key: hash it and check the local cache.
On cache miss, call the IAM service to resolve.
4. If neither succeeds, return 401.
All user management, key management, credential validation, and token
signing logic lives in the IAM service. The gateway is a generic
enforcement point that can be replaced without changing the IAM
service.
#### No legacy token support
The existing `GATEWAY_SECRET` shared token is removed. All
authentication uses API keys or JWTs. On first start, the bootstrap
process creates a default workspace and admin user with an initial API
key.
### User identity
A user belongs to exactly one workspace. The design supports extending
this to multi-workspace access in the future (see
[Extension points](#extension-points)).
A user record contains:
| Field | Type | Description |
|-------|------|-------------|
| `id` | string | Unique user identifier (UUID) |
| `name` | string | Display name |
| `email` | string | Email address (optional) |
| `workspace` | string | Workspace the user belongs to |
| `roles` | list[string] | Assigned roles (e.g. `["reader"]`) |
| `enabled` | bool | Whether the user can authenticate |
| `created` | datetime | Account creation timestamp |
The `workspace` field maps to the existing `user` field in `Metadata`.
This means the storage-layer isolation (Cassandra, Neo4j, Qdrant
filtering by `user` + `collection`) works without changes — the gateway
sets the `user` metadata field to the authenticated user's workspace.
### Workspaces
A workspace is an isolated data boundary. Users belong to a workspace,
and all data operations are scoped to it. Workspaces map to the existing
`user` field in `Metadata` and the corresponding Cassandra keyspace,
Qdrant collection prefix, and Neo4j property filters.
| Field | Type | Description |
|-------|------|-------------|
| `id` | string | Unique workspace identifier |
| `name` | string | Display name |
| `enabled` | bool | Whether the workspace is active |
| `created` | datetime | Creation timestamp |
All data operations are scoped to a workspace. The gateway determines
the effective workspace for each request as follows:
1. If the request includes a `workspace` parameter, validate it against
the user's assigned workspace.
- If it matches, use it.
- If it does not match, return 403. (This could be extended to
check a workspace access grant list.)
2. If no `workspace` parameter is provided, use the user's assigned
workspace.
The gateway sets the `user` field in `Metadata` to the effective
workspace ID, replacing the caller-supplied `?user=` query parameter.
This design ensures forward compatibility. Clients that pass a
workspace parameter will work unchanged if multi-workspace support is
added later. Requests for an unassigned workspace get a clear 403
rather than silent misbehaviour.
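The two-step rule reduces to a small function. This is a sketch under the single-workspace model; error handling is simplified to an exception standing in for the HTTP 403.

```python
class Forbidden(Exception):
    """Stand-in for an HTTP 403 response."""

def effective_workspace(user_workspace, requested=None):
    if requested is None:
        # No workspace parameter: use the user's assigned workspace.
        return user_workspace
    if requested != user_workspace:
        # Mismatch: 403 (a grant list could be consulted here later).
        raise Forbidden("workspace not assigned to this user")
    return requested

# The gateway then sets the Metadata `user` field to this value.
```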
### Roles and access control
Three roles with fixed permissions:
| Role | Data operations | Admin operations | System |
|------|----------------|-----------------|--------|
| `reader` | Query knowledge graph, embeddings, RAG | None | None |
| `writer` | All reader operations + load documents, manage collections | None | None |
| `admin` | All writer operations | Config, flows, collection management, user management | Metrics |
Role checks happen at the gateway before dispatching to backend
services. Each endpoint declares the minimum role required:
| Endpoint pattern | Minimum role |
|-----------------|--------------|
| `GET /api/v1/socket` (queries) | `reader` |
| `POST /api/v1/librarian` | `writer` |
| `POST /api/v1/flow/*/import/*` | `writer` |
| `POST /api/v1/config` | `admin` |
| `GET /api/v1/flow/*` | `admin` |
| `GET /api/metrics` | `admin` |
Roles are hierarchical: `admin` implies `writer`, which implies
`reader`.
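Because the roles are strictly hierarchical, the gateway's check can compare ranks rather than enumerate permissions. A minimal sketch (the `RANK` table is illustrative):

```python
# Role hierarchy: admin implies writer implies reader.
RANK = {"reader": 0, "writer": 1, "admin": 2}

def has_role(user_roles, required):
    # A user may hold several roles; the highest-ranked one counts.
    best = max((RANK[r] for r in user_roles if r in RANK), default=-1)
    return best >= RANK[required]
```

Each endpoint then declares only its minimum role, and `has_role(user.roles, endpoint.min_role)` is the entire authorisation decision.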
### IAM service
The IAM service is a new backend service that manages all identity and
access data. It is the authority for users, workspaces, API keys, and
credentials. The gateway delegates to it.
#### Data model
```
iam_workspaces (
    id text PRIMARY KEY,
    name text,
    enabled boolean,
    created timestamp
)

iam_users (
    id text PRIMARY KEY,
    workspace text,
    name text,
    email text,
    password_hash text,
    roles set<text>,
    enabled boolean,
    created timestamp
)

iam_api_keys (
    key_hash text PRIMARY KEY,
    user_id text,
    name text,
    expires timestamp,
    created timestamp
)
```
A secondary index on `iam_api_keys.user_id` supports listing a user's
keys.
#### Responsibilities
- User CRUD (create, list, update, disable)
- Workspace CRUD (create, list, update, disable)
- API key management (create, revoke, list)
- API key resolution (hash → user/workspace/roles)
- Credential validation (username/password → signed JWT)
- JWT signing key management (initialise, rotate)
- Bootstrap (create default workspace and admin user on first start)
#### Communication
The IAM service communicates via the standard request/response pub/sub
pattern, the same as the config service. The gateway calls it to
resolve API keys and to handle login requests. User management
operations (create user, revoke key, etc.) also go through the IAM
service.
### Gateway changes
The current `Authenticator` class is replaced with a thin authentication
middleware that delegates to the IAM service:
For HTTP requests:
1. Extract Bearer token from the `Authorization` header.
2. If the token has JWT format (dotted structure):
- Validate signature locally using the cached public key.
- Extract user ID, workspace, and roles from claims.
3. Otherwise, treat as an API key:
- Hash the token and check the local cache.
- On cache miss, call the IAM service to resolve.
- Cache the result (user/workspace/roles) with a short TTL.
4. If neither succeeds, return 401.
5. If the user or workspace is disabled, return 403.
6. Check the user's role against the endpoint's minimum role. If
insufficient, return 403.
7. Resolve the effective workspace:
- If the request includes a `workspace` parameter, validate it
against the user's assigned workspace. Return 403 on mismatch.
- If no `workspace` parameter, use the user's assigned workspace.
8. Set the `user` field in the request context to the effective
workspace ID. This propagates through `Metadata` to all downstream
services.
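Two of the steps above — token classification (step 2 vs. 3) and workspace resolution (step 7) — can be sketched directly. JWT and cache plumbing are elided, and all names are illustrative:

```python
# Hedged sketch of the gateway middleware's token classification and
# workspace resolution. Names are hypothetical.

def classify_token(token):
    """JWTs have a three-part dotted structure; anything else is an API key."""
    return "jwt" if token.count(".") == 2 else "api-key"

def effective_workspace(requested, assigned):
    """Step 7: validate an explicit workspace parameter, else use the assignment."""
    if requested is not None and requested != assigned:
        raise PermissionError("workspace mismatch")   # surfaces as 403
    return requested or assigned

assert classify_token("aaa.bbb.ccc") == "jwt"
assert classify_token("tg_key_abc123") == "api-key"
assert effective_workspace(None, "ws-a") == "ws-a"
assert effective_workspace("ws-a", "ws-a") == "ws-a"
```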
For WebSocket connections:
1. Accept the connection in an unauthenticated state.
2. Wait for an auth message (`{"type": "auth", "token": "..."}`).
3. Validate the token using the same logic as steps 2-7 above.
4. On success, attach the resolved identity to the connection and
send `{"type": "auth-ok", ...}`.
5. On failure, send `{"type": "auth-failed", ...}` but keep the
socket open.
6. Reject all non-auth messages until authentication succeeds.
7. Accept new auth messages at any time to re-authenticate.
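The connection is effectively a small state machine. The sketch below is synchronous and illustrative; the real handler is asynchronous and delegates validation to the IAM service:

```python
# Illustrative sketch of the first-message WebSocket auth protocol.
class WsConnection:
    def __init__(self, validate):
        self.validate = validate      # token -> identity dict, or None
        self.identity = None

    def on_message(self, msg):
        if msg.get("type") == "auth":
            identity = self.validate(msg.get("token"))
            if identity is None:
                return {"type": "auth-failed"}   # socket stays open
            self.identity = identity             # re-auth replaces identity
            return {"type": "auth-ok"}
        if self.identity is None:
            return {"type": "error", "error": "not authenticated"}
        return self.handle(msg)

    def handle(self, msg):
        return {"type": "ack"}

conn = WsConnection(lambda t: {"user": "u-1"} if t == "good" else None)
assert conn.on_message({"type": "query"})["type"] == "error"
assert conn.on_message({"type": "auth", "token": "bad"})["type"] == "auth-failed"
assert conn.on_message({"type": "auth", "token": "good"})["type"] == "auth-ok"
assert conn.on_message({"type": "query"})["type"] == "ack"
```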
### CLI changes
CLI tools authenticate with API keys:
- `--api-key` argument on all CLI tools, replacing `--api-token`.
- `tg-create-workspace`, `tg-list-workspaces` for workspace management.
- `tg-create-user`, `tg-list-users`, `tg-disable-user` for user
management.
- `tg-create-api-key`, `tg-list-api-keys`, `tg-revoke-api-key` for
key management.
- `--workspace` argument on tools that operate on workspace-scoped
data.
- The API key is passed as a Bearer token in the same way as the
current shared token, so the transport protocol is unchanged.
### Audit logging
With user identity established, the gateway logs:
- Timestamp, user ID, workspace, endpoint, HTTP method, response status.
- Audit logs are written to the standard logging output (structured
JSON). Integration with external log aggregation (Loki, ELK) is a
deployment concern, not an application concern.
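A structured audit record built from the fields listed above might look like this. Field names beyond those listed are illustrative:

```python
# Sketch of a structured JSON audit record; the exact field names are
# an assumption, not part of the spec.
import json
from datetime import datetime, timezone

def audit_record(user_id, workspace, endpoint, method, status):
    return json.dumps({
        "ts": datetime.now(timezone.utc).isoformat(),
        "user": user_id,
        "workspace": workspace,
        "endpoint": endpoint,
        "method": method,
        "status": status,
    })

rec = json.loads(audit_record("u-1", "ws-a", "/api/v1/librarian", "POST", 200))
assert rec["workspace"] == "ws-a" and rec["status"] == 200
```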
### Config service changes
All configuration is workspace-scoped (see
[data-ownership-model.md](data-ownership-model.md)). The config service
needs to support this.
#### Schema change
The config table adds workspace as a key dimension:
```
config (
workspace text,
class text,
key text,
value text,
PRIMARY KEY ((workspace, class), key)
)
```
#### Request format
Config requests add a `workspace` field at the request level. The
existing `(type, key)` structure is unchanged within each workspace.
**Get:**
```json
{
"operation": "get",
"workspace": "workspace-a",
"keys": [{"type": "prompt", "key": "rag-prompt"}]
}
```
**Put:**
```json
{
"operation": "put",
"workspace": "workspace-a",
"values": [{"type": "prompt", "key": "rag-prompt", "value": "..."}]
}
```
**List (all keys of a type within a workspace):**
```json
{
"operation": "list",
"workspace": "workspace-a",
"type": "prompt"
}
```
**Delete:**
```json
{
"operation": "delete",
"workspace": "workspace-a",
"keys": [{"type": "prompt", "key": "rag-prompt"}]
}
```
The workspace is set by:
- **Gateway** — from the authenticated user's workspace for API-facing
requests.
- **Internal services** — explicitly, based on `Metadata.user` from
the message being processed, or `_system` for operational config.
#### System config namespace
Processor-level operational config (logging levels, connection strings,
resource limits) is not workspace-specific. This stays in a reserved
`_system` workspace that is not associated with any user workspace.
Services read system config at startup without needing a workspace
context.
#### Config change notifications
The config notify mechanism pushes change notifications via pub/sub
when config is updated. A single update may affect multiple workspaces
and multiple config types. The notification message carries a dict of
changes keyed by config type, with each value being the list of
affected workspaces:
```json
{
"version": 42,
"changes": {
"prompt": ["workspace-a", "workspace-b"],
"schema": ["workspace-a"]
}
}
```
System config changes use the reserved `_system` workspace:
```json
{
"version": 43,
"changes": {
"logging": ["_system"]
}
}
```
This structure is keyed by type because handlers register by type. A
handler registered for `prompt` looks up `"prompt"` directly and gets
the list of affected workspaces — no iteration over unrelated types.
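The direct lookup can be sketched against the notification shape shown above (the helper name is illustrative):

```python
# Sketch of the type-keyed lookup: a handler registered for one config
# type reads its entry straight from the changes dict.
notification = {
    "version": 42,
    "changes": {
        "prompt": ["workspace-a", "workspace-b"],
        "schema": ["workspace-a"],
    },
}

def affected_workspaces(notification, config_type):
    """Workspaces whose config of this type changed (empty list if none)."""
    return notification["changes"].get(config_type, [])

assert affected_workspaces(notification, "prompt") == ["workspace-a", "workspace-b"]
assert affected_workspaces(notification, "agent") == []
```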
#### Config change handlers
The current `on_config` hook mechanism needs two modes to support shared
processing services:
- **Workspace-scoped handlers** — notify when a config type changes in a
specific workspace. The handler looks up its registered type in the
changes dict and checks if its workspace is in the list. Used by the
gateway and by services that serve a single workspace.
- **Global handlers** — notify when a config type changes in any
workspace. The handler looks up its registered type in the changes
dict and gets the full list of affected workspaces. Used by shared
processing services (prompt-rag, agent manager, etc.) that serve all
workspaces. Each workspace in the list tells the handler which cache
entry to update rather than reloading everything.
#### Per-workspace config caching
Shared services that handle messages from multiple workspaces maintain a
per-workspace config cache. When a message arrives, the service looks up
the config for the workspace identified in `Metadata.user`. If the
workspace is not yet cached, the service fetches its config on demand.
Config change notifications update the relevant cache entry.
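A minimal sketch of the cache behaviour — on-demand fetch plus targeted refresh from a notification. `fetch_config` stands in for a call to the config service; all names are illustrative:

```python
# Hedged sketch of per-workspace config caching in a shared service.
class WorkspaceConfigCache:
    def __init__(self, fetch_config):
        self.fetch = fetch_config          # (workspace, type) -> config
        self.cache = {}

    def get(self, workspace, config_type):
        """Return cached config, fetching on demand for a new workspace."""
        key = (workspace, config_type)
        if key not in self.cache:
            self.cache[key] = self.fetch(workspace, config_type)
        return self.cache[key]

    def on_change(self, notification, config_type):
        """Refresh only the cache entries named in the notification."""
        for ws in notification["changes"].get(config_type, []):
            self.cache[(ws, config_type)] = self.fetch(ws, config_type)

calls = []
def fetch(ws, t):
    calls.append(ws)
    return {"ws": ws}

cache = WorkspaceConfigCache(fetch)
cache.get("ws-a", "prompt")
cache.get("ws-a", "prompt")            # second call served from cache
assert calls == ["ws-a"]
cache.on_change({"changes": {"prompt": ["ws-a"]}}, "prompt")
assert calls == ["ws-a", "ws-a"]       # only ws-a refreshed
```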
### Flow and queue isolation
Flows are workspace-owned. When two workspaces start flows with the same
name and blueprint, their queues must be separate to prevent data
mixing.
Flow blueprint templates currently use `{id}` (flow instance ID) and
`{class}` (blueprint name) as template variables in queue names. A new
`{workspace}` variable is added so queue names include the workspace:
**Current queue names (no workspace isolation):**
```
flow:tg:document-load:{id} → flow:tg:document-load:default
request:tg:embeddings:{class} → request:tg:embeddings:everything
```
**With workspace isolation:**
```
flow:tg:{workspace}:document-load:{id} → flow:tg:ws-a:document-load:default
request:tg:{workspace}:embeddings:{class} → request:tg:ws-a:embeddings:everything
```
The flow service substitutes `{workspace}` from the authenticated
workspace when starting a flow, the same way it substitutes `{id}` and
`{class}` today.
Processing services are shared infrastructure — they consume from
workspace-specific queues but are not themselves workspace-aware. The
workspace is carried in `Metadata.user` on every message, so services
know which workspace's data they are processing.
Blueprint templates need updating to include `{workspace}` in all queue
name patterns. For migration, the flow service can inject the workspace
into queue names automatically if the template does not include
`{workspace}`, defaulting to the legacy behaviour for existing
blueprints.
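The substitution plus the legacy fallback can be sketched as follows. The injection position (after the `flow:tg:` / `request:tg:` prefix) is an assumption for illustration, not a committed design:

```python
# Hedged sketch of queue-name substitution, including the migration
# fallback that injects the workspace into a legacy template that
# lacks {workspace}. Assumes the prefix:tg:... template shape.

def queue_name(template, workspace, flow_id, blueprint):
    if "{workspace}" not in template:
        # Legacy template: inject workspace after the first two segments.
        parts = template.split(":", 2)
        template = f"{parts[0]}:{parts[1]}:{{workspace}}:{parts[2]}"
    # "class" is a Python keyword, so it is passed via dict unpacking.
    return template.format(workspace=workspace, id=flow_id, **{"class": blueprint})

assert queue_name("flow:tg:{workspace}:document-load:{id}",
                  "ws-a", "default", "everything") \
    == "flow:tg:ws-a:document-load:default"
assert queue_name("request:tg:embeddings:{class}",
                  "ws-a", "default", "everything") \
    == "request:tg:ws-a:embeddings:everything"
```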
See [flow-class-definition.md](flow-class-definition.md) for the full
blueprint template specification.
### What changes and what doesn't
**Changes:**
| Component | Change |
|-----------|--------|
| `gateway/auth.py` | Replace `Authenticator` with new auth middleware |
| `gateway/service.py` | Initialise IAM client, configure JWT validation |
| `gateway/endpoint/*.py` | Add role requirement per endpoint |
| Metadata propagation | Gateway sets `user` from workspace, ignores query param |
| Config service | Add workspace dimension to config schema |
| Config table | `PRIMARY KEY ((workspace, class), key)` |
| Config request/response schema | Add `workspace` field |
| Config notify messages | Include workspace ID in change notifications |
| `on_config` handlers | Support workspace-scoped and global modes |
| Shared services | Per-workspace config caching |
| Flow blueprints | Add `{workspace}` template variable to queue names |
| Flow service | Substitute `{workspace}` when starting flows |
| CLI tools | New user management commands, `--api-key` argument |
| Cassandra schema | New `iam_workspaces`, `iam_users`, `iam_api_keys` tables |
**Does not change:**
| Component | Reason |
|-----------|--------|
| Internal service-to-service pub/sub | Services trust the gateway |
| `Metadata` dataclass | `user` field continues to carry workspace identity |
| Storage-layer isolation | Same `user` + `collection` filtering |
| Message serialisation | No schema changes |
### Migration
This is a breaking change. Existing deployments must be reconfigured:
1. `GATEWAY_SECRET` is removed. Authentication requires API keys or
JWT login tokens.
2. The `?user=` query parameter is removed. Workspace identity comes
from authentication.
3. On first start, the IAM service bootstraps a default workspace and
admin user. The initial API key is output to the service log.
4. Operators create additional workspaces and users via CLI tools.
5. Flow blueprints must be updated to include `{workspace}` in queue
name patterns.
6. Config data must be migrated to include the workspace dimension.
## Extension points
The design includes deliberate extension points for future capabilities.
These are not implemented but the architecture does not preclude them:
- **Multi-workspace access.** Users could be granted access to
additional workspaces beyond their primary assignment. The workspace
validation step checks a grant list instead of a single assignment.
- **Rules-based access control.** A separate access control service
could evaluate fine-grained policies (per-collection permissions,
operation-level restrictions, time-based access). The gateway
delegates authorisation decisions to this service.
- **External identity provider integration.** SAML, LDAP, and OIDC
flows (group mapping, claims-based role assignment) could be added
to the IAM service.
- **Cross-workspace administration.** A `superadmin` role for platform
operators who manage multiple workspaces.
- **Delegated workspace provisioning.** APIs for programmatic workspace
creation and user onboarding.
These extensions are additive — they extend the validation logic
without changing the request/response protocol. The gateway can be
replaced with an alternative implementation that supports these
capabilities while the IAM service and backend services remain
unchanged.
## Implementation plan
Workspace support is a prerequisite for auth — users are assigned to
workspaces, config is workspace-scoped, and flows use workspace in
queue names. Implementing workspaces first allows the structural changes
to be tested end-to-end without auth complicating debugging.
### Phase 1: Workspace support (no auth)
All workspace-scoped data and processing changes. The system works with
workspaces but without authentication: callers pass the workspace as a
parameter on the honour system. This allows full end-to-end testing: multiple
workspaces with separate flows, config, queues, and data.
#### Config service
- Update config client API to accept a workspace parameter on all
requests
- Update config storage schema to add workspace as a key dimension
- Update config notification API to report changes as a dict of
type → workspace list
- Update the processor base class to understand workspaces in config
notifications (workspace-scoped and global handler modes)
- Update all processors to implement workspace-aware config handling
(per-workspace config caching, on-demand fetch)
#### Flow and queue isolation
- Update flow blueprints to include `{workspace}` in all queue name
patterns
- Update the flow service to substitute `{workspace}` when starting
flows
- Update all built-in blueprints to include `{workspace}`
#### CLI tools (workspace support)
- Add `--workspace` argument to CLI tools that operate on
workspace-scoped data
- Add `tg-create-workspace`, `tg-list-workspaces` commands
### Phase 2: Authentication and access control
With workspaces working, add the IAM service and lock down the gateway.
#### IAM service
A new service handling identity and access management on behalf of the
API gateway:
- Add workspace table support (CRUD, enable/disable)
- Add user table support (CRUD, enable/disable, workspace assignment)
- Add roles support (role assignment, role validation)
- Add API key support (create, revoke, list, hash storage)
- Add ability to initialise a JWT signing key for token grants
- Add token grant endpoint: user/password login returns a signed JWT
- Add bootstrap/initialisation mechanism: ability to set the signing
key and create the initial workspace + admin user on first start
#### API gateway integration
- Add IAM middleware to the API gateway replacing the current
`Authenticator`
- Add local JWT validation (public key from IAM service)
- Add API key resolution with local cache (hash → user/workspace/roles,
cache miss calls IAM service, short TTL)
- Add login endpoint forwarding to IAM service
- Add workspace resolution: validate requested workspace against user
assignment
- Add role-based endpoint access checks
- Add user management API endpoints (forwarded to IAM service)
- Add audit logging (user ID, workspace, endpoint, method, status)
- WebSocket auth via first-message protocol (auth message after
connect, socket stays open on failure, re-auth supported)
#### CLI tools (auth support)
- Add `tg-create-user`, `tg-list-users`, `tg-disable-user` commands
- Add `tg-create-api-key`, `tg-list-api-keys`, `tg-revoke-api-key`
commands
- Replace `--api-token` with `--api-key` on existing CLI tools
#### Bootstrap and cutover
- Create default workspace and admin user on first start if IAM tables
are empty
- Remove `GATEWAY_SECRET` and `?user=` query parameter support
## Design Decisions
### IAM data store
IAM data is stored in dedicated Cassandra tables owned by the IAM
service, not in the config service. Reasons:
- **Security isolation.** The config service has a broad, generic
protocol. An access control failure on the config service could
expose credentials. A dedicated IAM service with a purpose-built
protocol limits the attack surface and makes security auditing
clearer.
- **Data model fit.** IAM needs indexed lookups (API key hash → user,
list keys by user). The config service's `(workspace, type, key) →
value` model stores opaque JSON strings with no secondary indexes.
- **Scope.** IAM data is global (workspaces, users, keys). Config is
workspace-scoped. Mixing global and workspace-scoped data in the
same store adds complexity.
- **Audit.** IAM operations (key creation, revocation, login attempts)
are security events that should be logged separately from general
config changes.
## Deferred to future design
- **OIDC integration.** External identity provider support (SAML, LDAP,
OIDC) is left for future implementation. The extension points section
describes where this fits architecturally.
- **API key scoping.** API keys could be scoped to specific collections
within a workspace rather than granting workspace-wide access. To be
designed when the need arises.
- **Multi-workspace initialisation.** `tg-init-trustgraph` only
  initialises a single workspace.
## References
- [Data Ownership and Information Separation](data-ownership-model.md)
- [MCP Tool Bearer Token Specification](mcp-tool-bearer-token.md)
- [Multi-Tenant Support Specification](multi-tenant-support.md)
- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md)

View file

@ -1,8 +0,0 @@
name: user
in: query
required: false
schema:
type: string
default: trustgraph
description: User identifier
example: alice

View file

@ -43,15 +43,6 @@ properties:
type: string type: string
description: Result of the action description: Result of the action
example: "Paris is the capital of France" example: "Paris is the capital of France"
user:
type: string
description: User context for this step
example: alice
user:
type: string
description: User identifier for multi-tenancy
default: trustgraph
example: alice
streaming: streaming:
type: boolean type: boolean
description: Enable streaming response delivery description: Enable streaming response delivery

View file

@ -14,14 +14,9 @@ properties:
- delete-collection - delete-collection
description: | description: |
Collection operation: Collection operation:
- `list-collections`: List collections for user - `list-collections`: List collections in workspace
- `update-collection`: Create or update collection metadata - `update-collection`: Create or update collection metadata
- `delete-collection`: Delete collection - `delete-collection`: Delete collection
user:
type: string
description: User identifier
default: trustgraph
example: alice
collection: collection:
type: string type: string
description: Collection identifier (for update, delete) description: Collection identifier (for update, delete)

View file

@ -12,13 +12,8 @@ properties:
items: items:
type: object type: object
required: required:
- user
- collection - collection
properties: properties:
user:
type: string
description: User identifier
example: alice
collection: collection:
type: string type: string
description: Collection identifier description: Collection identifier

View file

@ -17,11 +17,6 @@ properties:
minimum: 1 minimum: 1
maximum: 1000 maximum: 1000
example: 20 example: 20
user:
type: string
description: User identifier
default: trustgraph
example: alice
collection: collection:
type: string type: string
description: Collection to search description: Collection to search

View file

@ -17,11 +17,6 @@ properties:
minimum: 1 minimum: 1
maximum: 1000 maximum: 1000
example: 20 example: 20
user:
type: string
description: User identifier
default: trustgraph
example: alice
collection: collection:
type: string type: string
description: Collection to search description: Collection to search

View file

@ -27,11 +27,6 @@ properties:
minimum: 1 minimum: 1
maximum: 1000 maximum: 1000
example: 20 example: 20
user:
type: string
description: User identifier
default: trustgraph
example: alice
collection: collection:
type: string type: string
description: Collection to search description: Collection to search

View file

@ -18,17 +18,12 @@ properties:
- unload-kg-core - unload-kg-core
description: | description: |
Knowledge core operation: Knowledge core operation:
- `list-kg-cores`: List knowledge cores for user - `list-kg-cores`: List knowledge cores in workspace
- `get-kg-core`: Get knowledge core by ID - `get-kg-core`: Get knowledge core by ID
- `put-kg-core`: Store triples and/or embeddings - `put-kg-core`: Store triples and/or embeddings
- `delete-kg-core`: Delete knowledge core by ID - `delete-kg-core`: Delete knowledge core by ID
- `load-kg-core`: Load knowledge core into flow - `load-kg-core`: Load knowledge core into flow
- `unload-kg-core`: Unload knowledge core from flow - `unload-kg-core`: Unload knowledge core from flow
user:
type: string
description: User identifier (for list-kg-cores, put-kg-core, delete-kg-core)
default: trustgraph
example: alice
id: id:
type: string type: string
description: Knowledge core ID (for get, put, delete, load, unload) description: Knowledge core ID (for get, put, delete, load, unload)
@ -53,17 +48,12 @@ properties:
type: object type: object
required: required:
- id - id
- user
- collection - collection
properties: properties:
id: id:
type: string type: string
description: Knowledge core ID description: Knowledge core ID
example: core-123 example: core-123
user:
type: string
description: User identifier
example: alice
collection: collection:
type: string type: string
description: Collection identifier description: Collection identifier
@ -89,17 +79,12 @@ properties:
type: object type: object
required: required:
- id - id
- user
- collection - collection
properties: properties:
id: id:
type: string type: string
description: Knowledge core ID description: Knowledge core ID
example: core-123 example: core-123
user:
type: string
description: User identifier
example: alice
collection: collection:
type: string type: string
description: Collection identifier description: Collection identifier

View file

@ -15,17 +15,12 @@ properties:
type: object type: object
required: required:
- id - id
- user
- collection - collection
properties: properties:
id: id:
type: string type: string
description: Knowledge core ID description: Knowledge core ID
example: core-123 example: core-123
user:
type: string
description: User identifier
example: alice
collection: collection:
type: string type: string
description: Collection identifier description: Collection identifier
@ -48,17 +43,12 @@ properties:
type: object type: object
required: required:
- id - id
- user
- collection - collection
properties: properties:
id: id:
type: string type: string
description: Knowledge core ID description: Knowledge core ID
example: core-123 example: core-123
user:
type: string
description: User identifier
example: alice
collection: collection:
type: string type: string
description: Collection identifier description: Collection identifier

View file

@ -62,11 +62,6 @@ properties:
description: Collection identifier description: Collection identifier
default: default default: default
example: default example: default
user:
type: string
description: User identifier
default: trustgraph
example: alice
document-id: document-id:
type: string type: string
description: Document identifier description: Document identifier

View file

@ -15,11 +15,6 @@ properties:
type: string type: string
description: Document identifier description: Document identifier
example: doc-456 example: doc-456
user:
type: string
description: User identifier
default: trustgraph
example: alice
collection: collection:
type: string type: string
description: Collection for document description: Collection for document

View file

@ -14,11 +14,6 @@ properties:
type: string type: string
description: Document identifier description: Document identifier
example: doc-123 example: doc-123
user:
type: string
description: User identifier
default: trustgraph
example: alice
collection: collection:
type: string type: string
description: Collection for document description: Collection for document

View file

@ -28,11 +28,6 @@ properties:
type: string type: string
description: Operation name (for multi-operation documents) description: Operation name (for multi-operation documents)
example: GetPerson example: GetPerson
user:
type: string
description: User identifier
default: trustgraph
example: alice
collection: collection:
type: string type: string
description: Collection to query description: Collection to query

View file

@ -10,11 +10,6 @@ properties:
type: string type: string
description: Natural language question description: Natural language question
example: Who does Alice know that works in engineering? example: Who does Alice know that works in engineering?
user:
type: string
description: User identifier
default: trustgraph
example: alice
collection: collection:
type: string type: string
description: Collection to query description: Collection to query

View file

@ -18,11 +18,6 @@ properties:
minimum: 1 minimum: 1
maximum: 100000 maximum: 100000
example: 100 example: 100
user:
type: string
description: User identifier
default: trustgraph
example: alice
collection: collection:
type: string type: string
description: Collection to query description: Collection to query

View file

@ -9,11 +9,6 @@ properties:
type: string type: string
description: User query or question description: User query or question
example: What are the key findings in the research papers? example: What are the key findings in the research papers?
user:
type: string
description: User identifier for multi-tenancy
default: trustgraph
example: alice
collection: collection:
type: string type: string
description: Collection to search within description: Collection to search within

View file

@ -9,11 +9,6 @@ properties:
type: string type: string
description: User query or question description: User query or question
example: What connections exist between quantum physics and computer science? example: What connections exist between quantum physics and computer science?
user:
type: string
description: User identifier for multi-tenancy
default: trustgraph
example: alice
collection: collection:
type: string type: string
description: Collection to search within description: Collection to search within

View file

@ -10,11 +10,10 @@ post:
Collections are organizational units for grouping: Collections are organizational units for grouping:
- Documents in the librarian - Documents in the librarian
- Knowledge cores - Knowledge cores
- User data - Workspace data
Each collection has: Each collection has:
- **user**: Owner identifier - **collection**: Unique collection ID (within the workspace)
- **collection**: Unique collection ID
- **name**: Human-readable display name - **name**: Human-readable display name
- **description**: Purpose and contents - **description**: Purpose and contents
- **tags**: Labels for filtering and organization - **tags**: Labels for filtering and organization
@ -22,7 +21,7 @@ post:
## Operations ## Operations
### list-collections ### list-collections
List all collections for a user. Optionally filter by tags and limit results. List all collections in the workspace. Optionally filter by tags and limit results.
Returns array of collection metadata. Returns array of collection metadata.
### update-collection ### update-collection
@ -30,7 +29,7 @@ post:
If it exists, metadata is updated. Allows setting name, description, and tags. If it exists, metadata is updated. Allows setting name, description, and tags.
### delete-collection ### delete-collection
Delete a collection by user and collection ID. This removes the metadata but Delete a collection by collection ID. This removes the metadata but
typically does not delete the associated data (documents, knowledge cores). typically does not delete the associated data (documents, knowledge cores).
operationId: collectionManagementService operationId: collectionManagementService
@ -44,22 +43,19 @@ post:
$ref: '../components/schemas/collection/CollectionRequest.yaml' $ref: '../components/schemas/collection/CollectionRequest.yaml'
examples: examples:
listCollections: listCollections:
summary: List all collections for user summary: List all collections in workspace
value: value:
operation: list-collections operation: list-collections
user: alice
listCollectionsFiltered: listCollectionsFiltered:
summary: List collections filtered by tags summary: List collections filtered by tags
value: value:
operation: list-collections operation: list-collections
user: alice
tag-filter: ["research", "AI"] tag-filter: ["research", "AI"]
limit: 50 limit: 50
updateCollection: updateCollection:
summary: Create/update collection summary: Create/update collection
value: value:
operation: update-collection operation: update-collection
user: alice
collection: research collection: research
name: Research Papers name: Research Papers
description: Academic research papers on AI and ML description: Academic research papers on AI and ML
@ -69,7 +65,6 @@ post:
summary: Delete collection summary: Delete collection
value: value:
operation: delete-collection operation: delete-collection
user: alice
collection: research collection: research
responses: responses:
'200': '200':
@ -84,13 +79,11 @@ post:
value: value:
timestamp: "2024-01-15T10:30:00Z" timestamp: "2024-01-15T10:30:00Z"
collections: collections:
- user: alice - collection: research
collection: research
name: Research Papers name: Research Papers
description: Academic research papers on AI and ML description: Academic research papers on AI and ML
tags: ["research", "AI", "academic"] tags: ["research", "AI", "academic"]
- user: alice - collection: personal
collection: personal
name: Personal Documents name: Personal Documents
description: Personal notes and documents description: Personal notes and documents
tags: ["personal"] tags: ["personal"]

View file

@ -8,7 +8,6 @@ get:
## Parameters ## Parameters
- `user`: User identifier (required)
- `document-id`: Document IRI to retrieve (required) - `document-id`: Document IRI to retrieve (required)
- `chunk-size`: Size of each response chunk in bytes (optional, default: 1MB) - `chunk-size`: Size of each response chunk in bytes (optional, default: 1MB)
@ -16,13 +15,6 @@ get:
security: security:
- bearerAuth: [] - bearerAuth: []
parameters: parameters:
- name: user
in: query
required: true
schema:
type: string
description: User identifier
example: trustgraph
- name: document-id - name: document-id
in: query in: query
required: true required: true

View file

@ -23,7 +23,6 @@ get:
"m": { // Metadata "m": { // Metadata
"i": "core-id", // Knowledge core ID "i": "core-id", // Knowledge core ID
"m": [...], // Metadata triples array "m": [...], // Metadata triples array
"u": "user", // User
"c": "collection" // Collection "c": "collection" // Collection
}, },
"t": [...] // Triples array "t": [...] // Triples array
@ -36,7 +35,6 @@ get:
"m": { // Metadata "m": { // Metadata
"i": "core-id", "i": "core-id",
"m": [...], "m": [...],
"u": "user",
"c": "collection" "c": "collection"
}, },
"e": [ // Entities array "e": [ // Entities array
@ -56,7 +54,6 @@ get:
## Query Parameters ## Query Parameters
- **id**: Knowledge core ID to export - **id**: Knowledge core ID to export
- **user**: User identifier
## Streaming ## Streaming
@ -86,13 +83,6 @@ get:
type: string type: string
description: Knowledge core ID to export description: Knowledge core ID to export
example: core-123 example: core-123
- name: user
in: query
required: true
schema:
type: string
description: User identifier
example: alice
responses: responses:
'200': '200':
description: Export stream description: Export stream

View file

@ -69,25 +69,21 @@ post:
summary: Simple question summary: Simple question
value: value:
question: What is the capital of France? question: What is the capital of France?
user: alice
streamingQuestion: streamingQuestion:
summary: Question with streaming enabled summary: Question with streaming enabled
value: value:
question: Explain quantum computing question: Explain quantum computing
user: alice
streaming: true streaming: true
conversationWithHistory: conversationWithHistory:
summary: Multi-turn conversation summary: Multi-turn conversation
value: value:
question: And what about its population? question: And what about its population?
user: alice
history: history:
- thought: User is asking about the capital of France - thought: User is asking about the capital of France
action: search action: search
arguments: arguments:
query: "capital of France" query: "capital of France"
observation: "Paris is the capital of France" observation: "Paris is the capital of France"
user: alice
responses: responses:
'200': '200':
description: Successful response description: Successful response

View file

@ -75,7 +75,6 @@ post:
value: value:
vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178] vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178]
limit: 10 limit: 10
user: alice
collection: research collection: research
largeQuery: largeQuery:
summary: Larger result set summary: Larger result set

View file

@@ -88,14 +88,12 @@ post:
value:
data: JVBERi0xLjQKJeLjz9MKMSAwIG9iago8PC9UeXBlL0NhdGFsb2cvUGFnZXMgMiAwIFI+PmVuZG9iagoyIDAgb2JqCjw8L1R5cGUvUGFnZXMvS2lkc1szIDAgUl0vQ291bnQgMT4+ZW5kb2JqCg==
id: doc-789
-user: alice
collection: research
withMetadata:
summary: Load with metadata
value:
data: JVBERi0xLjQKJeLjz9MK...
id: doc-101112
-user: bob
collection: papers
metadata:
- s: {v: "doc-101112", e: false}
View file
@@ -40,7 +40,6 @@ post:
- Higher = more context but slower
- Lower = faster but may miss relevant info
- **collection**: Target specific document collection
-- **user**: Multi-tenant isolation
operationId: documentRagService
security:
@@ -64,13 +63,11 @@ post:
summary: Basic document query
value:
query: What are the key findings in the research papers?
-user: alice
collection: research
streamingQuery:
summary: Streaming query
value:
query: Summarize the main conclusions
-user: alice
collection: research
doc-limit: 15
streaming: true
View file
@@ -66,7 +66,6 @@ post:
value:
vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178]
limit: 10
-user: alice
collection: research
largeQuery:
summary: Larger result set
View file
@@ -77,13 +77,11 @@ post:
summary: Basic graph query
value:
query: What connections exist between quantum physics and computer science?
-user: alice
collection: research
streamingQuery:
summary: Streaming query with custom limits
value:
query: Trace the historical development of AI from Turing to modern LLMs
-user: alice
collection: research
entity-limit: 40
triple-limit: 25
View file
@@ -62,7 +62,6 @@ post:
vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178]
schema_name: customers
limit: 10
-user: alice
collection: sales
filteredQuery:
summary: Search specific index
View file
@@ -89,7 +89,6 @@ post:
email
}
}
-user: alice
collection: research
queryWithVariables:
summary: Query with variables
View file
@@ -61,10 +61,6 @@ post:
query:
type: string
description: SPARQL 1.1 query string
-user:
-type: string
-default: trustgraph
-description: User/keyspace identifier
collection:
type: string
default: default
@@ -78,7 +74,6 @@ post:
summary: SELECT query
value:
query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10"
-user: trustgraph
collection: default
askQuery:
summary: ASK query
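After this change, the example SELECT request reduces to a two-field payload. A minimal sketch of the resulting body (field names taken from the schema above; request wiring omitted):

```python
# Post-refactor SPARQL request payload: the "user" key is gone;
# "collection" still defaults to "default".
payload = {
    "query": "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10",
    "collection": "default",
}

# The workspace is now implied by the caller's context, not the message.
assert "user" not in payload
```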
View file
@@ -79,13 +79,11 @@ post:
summary: Simple relationship question
value:
question: Who does Alice know?
-user: alice
collection: research
complexQuestion:
summary: Complex multi-hop question
value:
question: What companies employ engineers that Bob collaborates with?
-user: bob
collection: work
filterQuestion:
summary: Question with implicit filters
View file
@@ -87,14 +87,12 @@ post:
value:
text: This is the document text...
id: doc-123
-user: alice
collection: research
withMetadata:
summary: Load with RDF metadata using base64 text
value:
text: UXVhbnR1bSBjb21wdXRpbmcgdXNlcyBxdWFudHVtIG1lY2hhbmljcyBwcmluY2lwbGVzLi4u
id: doc-456
-user: alice
collection: research
metadata:
- s: {v: "doc-456", e: false}
View file
@@ -81,7 +81,6 @@ post:
s:
v: https://example.com/person/alice
e: true
-user: alice
collection: research
limit: 100
allInstancesOfType:
@@ -100,7 +99,6 @@ post:
p:
v: https://example.com/knows
e: true
-user: alice
limit: 200
responses:
'200':
View file
@@ -23,7 +23,6 @@ post:
"m": { // Metadata
"i": "core-id", // Knowledge core ID
"m": [...], // Metadata triples array
-"u": "user", // User
"c": "collection" // Collection
},
"t": [...] // Triples array
@@ -36,7 +35,6 @@ post:
"m": { // Metadata
"i": "core-id",
"m": [...],
-"u": "user",
"c": "collection"
},
"e": [ // Entities array
@@ -51,7 +49,6 @@ post:
## Query Parameters
- **id**: Knowledge core ID
-- **user**: User identifier
## Streaming
@@ -77,13 +74,6 @@ post:
type: string
description: Knowledge core ID to import
example: core-123
-- name: user
-in: query
-required: true
-schema:
-type: string
-description: User identifier
-example: alice
requestBody:
required: true
content:
View file
@@ -12,12 +12,12 @@ post:
- **Graph Embeddings**: Vector embeddings for entities
- **Metadata**: Descriptive information about the knowledge
-Each core has an ID, user, and collection for organization.
+Each core has an ID and collection for organization (within the workspace).
## Operations
### list-kg-cores
-List all knowledge cores for a user. Returns array of core IDs.
+List all knowledge cores in the workspace. Returns array of core IDs.
### get-kg-core
Retrieve a knowledge core by ID. Returns triples and/or graph embeddings.
@@ -58,7 +58,6 @@ post:
summary: List knowledge cores
value:
operation: list-kg-cores
-user: alice
getKnowledgeCore:
summary: Get knowledge core
value:
@@ -71,7 +70,6 @@ post:
triples:
metadata:
id: core-123
-user: alice
collection: default
metadata:
- s: {v: "https://example.com/core-123", e: true}
@@ -91,7 +89,6 @@ post:
graph-embeddings:
metadata:
id: core-123
-user: alice
collection: default
metadata: []
entities:
@@ -106,7 +103,6 @@ post:
triples:
metadata:
id: core-456
-user: bob
collection: research
metadata: []
triples:
@@ -116,7 +112,6 @@ post:
graph-embeddings:
metadata:
id: core-456
-user: bob
collection: research
metadata: []
entities:
@@ -127,7 +122,6 @@ post:
value:
operation: delete-kg-core
id: core-123
-user: alice
loadKnowledgeCore:
summary: Load core into flow
value:
@@ -161,7 +155,6 @@ post:
triples:
metadata:
id: core-123
-user: alice
collection: default
metadata:
- s: {v: "https://example.com/core-123", e: true}
@@ -177,7 +170,6 @@ post:
graph-embeddings:
metadata:
id: core-123
-user: alice
collection: default
metadata: []
entities:
View file
@@ -26,5 +26,4 @@ examples:
vectors: [0.023, -0.142, 0.089, 0.234]
schema_name: customers
limit: 10
-user: trustgraph
collection: default
View file
@@ -24,10 +24,6 @@ properties:
query:
type: string
description: SPARQL 1.1 query string
-user:
-type: string
-default: trustgraph
-description: User/keyspace identifier
collection:
type: string
default: default
@@ -42,5 +38,4 @@ examples:
flow: my-flow
request:
query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10"
-user: trustgraph
collection: default
View file
@@ -72,7 +72,6 @@ def sample_message_data():
},
"DocumentRagQuery": {
"query": "What is artificial intelligence?",
-"user": "test_user",
"collection": "test_collection",
"doc_limit": 10
},
@@ -95,7 +94,6 @@ def sample_message_data():
},
"Metadata": {
"id": "test-doc-123",
-"user": "test_user",
"collection": "test_collection"
},
"Term": {
@@ -130,9 +128,8 @@ def invalid_message_data():
{}, # Missing required fields
],
"DocumentRagQuery": [
-{"query": None, "user": "test", "collection": "test", "doc_limit": 10}, # Invalid query
-{"query": "test", "user": None, "collection": "test", "doc_limit": 10}, # Invalid user
-{"query": "test", "user": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit
+{"query": None, "collection": "test", "doc_limit": 10}, # Invalid query
+{"query": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit
{"query": "test"}, # Missing required fields
],
"Term": [
View file
@@ -18,24 +18,18 @@ class TestDocumentEmbeddingsRequestContract:
def test_request_schema_fields(self):
"""Test that DocumentEmbeddingsRequest has expected fields"""
-# Create a request
request = DocumentEmbeddingsRequest(
vector=[0.1, 0.2, 0.3],
limit=10,
-user="test_user",
collection="test_collection"
)
-# Verify all expected fields exist
assert hasattr(request, 'vector')
assert hasattr(request, 'limit')
-assert hasattr(request, 'user')
assert hasattr(request, 'collection')
-# Verify field values
assert request.vector == [0.1, 0.2, 0.3]
assert request.limit == 10
-assert request.user == "test_user"
assert request.collection == "test_collection"
def test_request_translator_decode(self):
@@ -45,7 +39,6 @@ class TestDocumentEmbeddingsRequestContract:
data = {
"vector": [0.1, 0.2, 0.3, 0.4],
"limit": 5,
-"user": "custom_user",
"collection": "custom_collection"
}
@@ -54,7 +47,6 @@ class TestDocumentEmbeddingsRequestContract:
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vector == [0.1, 0.2, 0.3, 0.4]
assert result.limit == 5
-assert result.user == "custom_user"
assert result.collection == "custom_collection"
def test_request_translator_decode_with_defaults(self):
@@ -63,7 +55,7 @@ class TestDocumentEmbeddingsRequestContract:
data = {
"vector": [0.1, 0.2]
-# No limit, user, or collection provided
+# No limit or collection provided
}
result = translator.decode(data)
@@ -71,7 +63,6 @@ class TestDocumentEmbeddingsRequestContract:
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vector == [0.1, 0.2]
assert result.limit == 10 # Default
-assert result.user == "trustgraph" # Default
assert result.collection == "default" # Default
def test_request_translator_encode(self):
@@ -81,7 +72,6 @@ class TestDocumentEmbeddingsRequestContract:
request = DocumentEmbeddingsRequest(
vector=[0.5, 0.6],
limit=20,
-user="test_user",
collection="test_collection"
)
@@ -90,7 +80,6 @@ class TestDocumentEmbeddingsRequestContract:
assert isinstance(result, dict)
assert result["vector"] == [0.5, 0.6]
assert result["limit"] == 20
-assert result["user"] == "test_user"
assert result["collection"] == "test_collection"
@@ -219,7 +208,6 @@ class TestDocumentEmbeddingsMessageCompatibility:
request_data = {
"vector": [0.1, 0.2, 0.3],
"limit": 5,
-"user": "test_user",
"collection": "test_collection"
}
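The defaults that the updated contract tests assert can be sketched with a stand-in dataclass. This is illustrative only; the real `DocumentEmbeddingsRequest` is defined in the TrustGraph schema package, not as a plain dataclass:

```python
from dataclasses import dataclass
from typing import List

# Illustrative stand-in: after the refactor there is no `user` field,
# and only `limit` and `collection` carry defaults.
@dataclass
class DocumentEmbeddingsRequest:
    vector: List[float]
    limit: int = 10              # default asserted by the tests above
    collection: str = "default"  # default asserted by the tests above

req = DocumentEmbeddingsRequest(vector=[0.1, 0.2])
assert req.limit == 10
assert req.collection == "default"
assert not hasattr(req, "user")  # the field no longer exists
```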
View file
@@ -132,7 +132,6 @@ class TestDocumentRagMessageContracts:
# Test required fields
query = DocumentRagQuery(**query_data)
assert hasattr(query, 'query')
-assert hasattr(query, 'user')
assert hasattr(query, 'collection')
assert hasattr(query, 'doc_limit')
@@ -154,12 +153,10 @@ class TestDocumentRagMessageContracts:
# Test valid query
valid_query = DocumentRagQuery(
query="What is AI?",
-user="test_user",
collection="test_collection",
doc_limit=5
)
assert valid_query.query == "What is AI?"
-assert valid_query.user == "test_user"
assert valid_query.collection == "test_collection"
assert valid_query.doc_limit == 5
@@ -400,7 +397,6 @@ class TestMetadataMessageContracts:
metadata = Metadata(**metadata_data)
assert metadata.id == "test-doc-123"
-assert metadata.user == "test_user"
assert metadata.collection == "test_collection"
def test_error_schema_contract(self):
@@ -491,7 +487,7 @@ class TestSchemaEvolutionContracts:
required_fields = {
"TextCompletionRequest": ["system", "prompt"],
"TextCompletionResponse": ["error", "response", "model"],
-"DocumentRagQuery": ["query", "user", "collection"],
+"DocumentRagQuery": ["query", "collection"],
"DocumentRagResponse": ["error", "response"],
"AgentRequest": ["question", "history"],
"AgentResponse": ["error"],
View file
@@ -18,7 +18,6 @@ class TestOrchestrationFieldContracts:
def test_agent_request_orchestration_fields_roundtrip(self):
req = AgentRequest(
question="Test question",
-user="testuser",
collection="default",
correlation_id="corr-123",
parent_session_id="parent-sess",
@@ -42,7 +41,6 @@ class TestOrchestrationFieldContracts:
def test_agent_request_orchestration_fields_default_empty(self):
req = AgentRequest(
question="Test question",
-user="testuser",
)
assert req.correlation_id == ""
@@ -82,7 +80,6 @@ class TestSubagentCompletionStepContract:
)
req = AgentRequest(
question="goal",
-user="testuser",
correlation_id="corr-123",
history=[step],
)
@@ -126,7 +123,6 @@ class TestSynthesisStepContract:
req = AgentRequest(
question="Original question",
-user="testuser",
pattern="supervisor",
correlation_id="",
session_id="parent-sess",
View file
@@ -22,7 +22,6 @@ class TestRowsCassandraContracts:
# Create test object with all required fields
test_metadata = Metadata(
id="test-doc-001",
-user="test_user",
collection="test_collection",
)
@@ -47,7 +46,6 @@ class TestRowsCassandraContracts:
# Verify metadata structure
assert hasattr(test_object.metadata, 'id')
-assert hasattr(test_object.metadata, 'user')
assert hasattr(test_object.metadata, 'collection')
# Verify types
@@ -150,7 +148,6 @@ class TestRowsCassandraContracts:
original = ExtractedObject(
metadata=Metadata(
id="serial-001",
-user="test_user",
collection="test_coll",
),
schema_name="test_schema",
@@ -168,7 +165,6 @@ class TestRowsCassandraContracts:
# Verify round-trip
assert decoded.metadata.id == original.metadata.id
-assert decoded.metadata.user == original.metadata.user
assert decoded.metadata.collection == original.metadata.collection
assert decoded.schema_name == original.schema_name
assert decoded.values == original.values
@@ -228,8 +224,7 @@ class TestRowsCassandraContracts:
# Create test object
test_obj = ExtractedObject(
metadata=Metadata(
-id="meta-001",
-user="user123", # -> keyspace
+id="meta-001", # -> keyspace
collection="coll456", # -> partition key
),
schema_name="table789", # -> table name
@@ -242,7 +237,6 @@ class TestRowsCassandraContracts:
# - metadata.user -> Cassandra keyspace
# - schema_name -> Cassandra table
# - metadata.collection -> Part of primary key
-assert test_obj.metadata.user # Required for keyspace
assert test_obj.schema_name # Required for table
assert test_obj.metadata.collection # Required for partition key
@@ -256,7 +250,6 @@ class TestRowsCassandraContractsBatch:
# Create test object with multiple values in batch
test_metadata = Metadata(
id="batch-doc-001",
-user="test_user",
collection="test_collection",
)
@@ -302,7 +295,6 @@ class TestRowsCassandraContractsBatch:
"""Test empty batch ExtractedObject contract"""
test_metadata = Metadata(
id="empty-batch-001",
-user="test_user",
collection="test_collection",
)
@@ -324,7 +316,6 @@ class TestRowsCassandraContractsBatch:
"""Test single-item batch (backward compatibility) contract"""
test_metadata = Metadata(
id="single-batch-001",
-user="test_user",
collection="test_collection",
)
@@ -353,7 +344,6 @@ class TestRowsCassandraContractsBatch:
original = ExtractedObject(
metadata=Metadata(
id="batch-serial-001",
-user="test_user",
collection="test_coll",
),
schema_name="test_schema",
@@ -375,7 +365,6 @@ class TestRowsCassandraContractsBatch:
# Verify round-trip for batch
assert decoded.metadata.id == original.metadata.id
-assert decoded.metadata.user == original.metadata.user
assert decoded.metadata.collection == original.metadata.collection
assert decoded.schema_name == original.schema_name
assert len(decoded.values) == len(original.values)
@@ -425,8 +414,7 @@ class TestRowsCassandraContractsBatch:
# 3. Be stored in the same keyspace (user)
test_metadata = Metadata(
-id="partition-test-001",
-user="consistent_user", # Same keyspace
+id="partition-test-001", # Same keyspace
collection="consistent_collection", # Same partition
)
@@ -443,7 +431,6 @@ class TestRowsCassandraContractsBatch:
)
# Verify consistency contract
-assert batch_object.metadata.user # Must have user for keyspace
assert batch_object.metadata.collection # Must have collection for partition key
# Verify unique primary keys in batch
View file
@@ -21,7 +21,6 @@ class TestRowsGraphQLQueryContracts:
"""Test RowsQueryRequest schema structure and required fields"""
# Create test request with all required fields
test_request = RowsQueryRequest(
-user="test_user",
collection="test_collection",
query='{ customers { id name email } }',
variables={"status": "active", "limit": "10"},
@@ -29,21 +28,18 @@
)
# Verify all required fields are present
-assert hasattr(test_request, 'user')
assert hasattr(test_request, 'collection')
assert hasattr(test_request, 'query')
assert hasattr(test_request, 'variables')
assert hasattr(test_request, 'operation_name')
# Verify field types
-assert isinstance(test_request.user, str)
assert isinstance(test_request.collection, str)
assert isinstance(test_request.query, str)
assert isinstance(test_request.variables, dict)
assert isinstance(test_request.operation_name, str)
# Verify content
-assert test_request.user == "test_user"
assert test_request.collection == "test_collection"
assert "customers" in test_request.query
assert test_request.variables["status"] == "active"
@@ -53,7 +49,6 @@ class TestRowsGraphQLQueryContracts:
"""Test RowsQueryRequest with minimal required fields"""
# Create request with only essential fields
minimal_request = RowsQueryRequest(
-user="user",
collection="collection",
query='{ test }',
variables={},
@@ -61,7 +56,6 @@
)
# Verify minimal request is valid
-assert minimal_request.user == "user"
assert minimal_request.collection == "collection"
assert minimal_request.query == '{ test }'
assert minimal_request.variables == {}
@@ -187,7 +181,6 @@ class TestRowsGraphQLQueryContracts:
"""Test that request/response can be serialized/deserialized correctly"""
# Create original request
original_request = RowsQueryRequest(
-user="serialization_test",
collection="test_data",
query='{ orders(limit: 5) { id total customer { name } } }',
variables={"limit": "5", "status": "active"},
@@ -202,7 +195,6 @@
decoded_request = request_schema.decode(encoded_request)
# Verify request round-trip
-assert decoded_request.user == original_request.user
assert decoded_request.collection == original_request.collection
assert decoded_request.query == original_request.query
assert decoded_request.variables == original_request.variables
@@ -245,7 +237,7 @@
"""Test supported GraphQL query formats"""
# Test basic query
basic_query = RowsQueryRequest(
-user="test", collection="test", query='{ customers { id } }',
+collection="test", query='{ customers { id } }',
variables={}, operation_name=""
)
assert "customers" in basic_query.query
@@ -254,7 +246,7 @@
# Test query with variables
parameterized_query = RowsQueryRequest(
-user="test", collection="test",
+collection="test",
query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }',
variables={"status": "active", "limit": "10"},
operation_name="GetCustomers"
@@ -266,7 +258,7 @@
# Test complex nested query
nested_query = RowsQueryRequest(
-user="test", collection="test",
+collection="test",
query='''
{
customers(limit: 10) {
@@ -297,7 +289,7 @@
# This test verifies the current contract, though ideally we'd support all JSON types
variables_test = RowsQueryRequest(
-user="test", collection="test", query='{ test }',
+collection="test", query='{ test }',
variables={
"string_var": "test_value",
"numeric_var": "123", # Numbers as strings due to Map(String()) limitation
@@ -318,22 +310,18 @@
def test_cassandra_context_fields_contract(self):
"""Test that request contains necessary fields for Cassandra operations"""
-# Verify request has fields needed for Cassandra keyspace/table targeting
+# Verify request has fields needed for partition key targeting
request = RowsQueryRequest(
-user="keyspace_name", # Maps to Cassandra keyspace
collection="partition_collection", # Used in partition key
query='{ objects { id } }',
variables={}, operation_name=""
)
-# These fields are required for proper Cassandra operations
-assert request.user # Required for keyspace identification
-assert request.collection # Required for partition key
+# Required for partition key
+assert request.collection
# Verify field naming follows TrustGraph patterns (matching other query services)
-# This matches TriplesQueryRequest, DocumentEmbeddingsRequest patterns
-assert hasattr(request, 'user') # Same as TriplesQueryRequest.user
-assert hasattr(request, 'collection') # Same as TriplesQueryRequest.collection
+assert hasattr(request, 'collection')
def test_graphql_extensions_contract(self):
"""Test GraphQL extensions field format and usage"""
@@ -405,7 +393,7 @@
# Request to execute specific operation
multi_op_request = RowsQueryRequest(
-user="test", collection="test",
+collection="test",
query=multi_op_query,
variables={},
operation_name="GetCustomers"
@@ -418,7 +406,7 @@
# Test single operation (operation_name optional)
single_op_request = RowsQueryRequest(
-user="test", collection="test",
+collection="test",
query='{ customers { id } }',
variables={}, operation_name=""
)
View file
@@ -41,10 +41,11 @@ class TestSchemaFieldContracts:
def test_metadata_fields(self):
# NOTE: there is no `metadata` field. A previous regression
# constructed Metadata(metadata=...) and crashed at runtime.
+# `user` was also dropped in the workspace refactor — workspace
+# now flows via flow.workspace, not via message payload.
assert _field_names(Metadata) == {
"id",
"root",
-"user",
"collection",
}

View file

@@ -93,7 +93,6 @@ class TestStructuredDataSchemaContracts:
         # Arrange
         metadata = Metadata(
             id="structured-data-001",
-            user="test_user",
             collection="test_collection",
         )

@@ -118,7 +117,6 @@ class TestStructuredDataSchemaContracts:
         # Arrange
         metadata = Metadata(
             id="extracted-obj-001",
-            user="test_user",
             collection="test_collection",
         )

@@ -143,7 +141,6 @@ class TestStructuredDataSchemaContracts:
         # Arrange
         metadata = Metadata(
             id="extracted-batch-001",
-            user="test_user",
             collection="test_collection",
         )

@@ -177,7 +174,6 @@ class TestStructuredDataSchemaContracts:
         # Arrange
         metadata = Metadata(
             id="extracted-empty-001",
-            user="test_user",
             collection="test_collection",
         )

@@ -277,7 +273,6 @@ class TestStructuredEmbeddingsContracts:
         # Arrange
         metadata = Metadata(
             id="struct-embed-001",
-            user="test_user",
             collection="test_collection",
         )

@@ -308,7 +303,7 @@ class TestStructuredDataSerializationContracts:
     def test_structured_data_submission_serialization(self):
         """Test StructuredDataSubmission serialization contract"""
         # Arrange
-        metadata = Metadata(id="test", user="user", collection="col")
+        metadata = Metadata(id="test", collection="col")
         submission_data = {
             "metadata": metadata,
             "format": "json",

@@ -323,7 +318,7 @@ class TestStructuredDataSerializationContracts:
     def test_extracted_object_serialization(self):
         """Test ExtractedObject serialization contract"""
         # Arrange
-        metadata = Metadata(id="test", user="user", collection="col")
+        metadata = Metadata(id="test", collection="col")
         object_data = {
             "metadata": metadata,
             "schema_name": "test_schema",

@@ -373,7 +368,7 @@ class TestStructuredDataSerializationContracts:
     def test_extracted_object_batch_serialization(self):
         """Test ExtractedObject batch serialization contract"""
         # Arrange
-        metadata = Metadata(id="test", user="user", collection="col")
+        metadata = Metadata(id="test", collection="col")
         batch_object_data = {
             "metadata": metadata,
             "schema_name": "test_schema",

@@ -392,7 +387,7 @@ class TestStructuredDataSerializationContracts:
     def test_extracted_object_empty_batch_serialization(self):
         """Test ExtractedObject empty batch serialization contract"""
         # Arrange
-        metadata = Metadata(id="test", user="user", collection="col")
+        metadata = Metadata(id="test", collection="col")
         empty_batch_data = {
             "metadata": metadata,
             "schema_name": "test_schema",

View file

@@ -58,7 +58,7 @@ class TestAgentStructuredQueryIntegration:
     async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config):
         """Test basic agent integration with structured query tool"""
         # Arrange - Load tool configuration
-        await agent_processor.on_tools_config(structured_query_tool_config, "v1")
+        await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")

         # Create agent request
         request = AgentRequest(

@@ -66,7 +66,6 @@ class TestAgentStructuredQueryIntegration:
             state="",
             group=None,
             history=[],
-            user="test_user"
         )

         msg = MagicMock()

@@ -119,6 +118,7 @@ Args: {
         # Mock flow parameter in agent_processor.on_request
         flow = MagicMock()
         flow.side_effect = flow_context
+        flow.workspace = "default"

         # Act
         await agent_processor.on_request(msg, consumer, flow)

@@ -146,14 +146,13 @@ Args: {
     async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config):
         """Test agent handling of structured query errors"""
         # Arrange
-        await agent_processor.on_tools_config(structured_query_tool_config, "v1")
+        await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")

         request = AgentRequest(
             question="Find data from a table that doesn't exist using structured query.",
             state="",
             group=None,
             history=[],
-            user="test_user"
         )

         msg = MagicMock()

@@ -199,6 +198,7 @@ Args: {
         flow = MagicMock()
         flow.side_effect = flow_context
+        flow.workspace = "default"

         # Act
         await agent_processor.on_request(msg, consumer, flow)

@@ -221,14 +221,13 @@ Args: {
     async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
         """Test agent using structured query in multi-step reasoning"""
         # Arrange
-        await agent_processor.on_tools_config(structured_query_tool_config, "v1")
+        await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")

         request = AgentRequest(
             question="First find all customers from California, then tell me how many orders they have made.",
             state="",
             group=None,
             history=[],
-            user="test_user"
         )

         msg = MagicMock()

@@ -279,6 +278,7 @@ Args: {
         flow = MagicMock()
         flow.side_effect = flow_context
+        flow.workspace = "default"

         # Act
         await agent_processor.on_request(msg, consumer, flow)

@@ -313,14 +313,13 @@ Args: {
             }
         }

-        await agent_processor.on_tools_config(tool_config_with_collection, "v1")
+        await agent_processor.on_tools_config("default", tool_config_with_collection, "v1")

         request = AgentRequest(
             question="Query the sales data for recent transactions.",
             state="",
             group=None,
             history=[],
-            user="test_user"
         )

         msg = MagicMock()

@@ -371,6 +370,7 @@ Args: {
         flow = MagicMock()
         flow.side_effect = flow_context
+        flow.workspace = "default"

         # Act
         await agent_processor.on_request(msg, consumer, flow)

@@ -394,10 +394,10 @@ Args: {
     async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
         """Test that structured query tool arguments are properly validated"""
         # Arrange
-        await agent_processor.on_tools_config(structured_query_tool_config, "v1")
+        await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")

         # Check that the tool was registered with correct arguments
-        tools = agent_processor.agent.tools
+        tools = agent_processor.agents["default"].tools
         assert "structured-query" in tools

         structured_tool = tools["structured-query"]

@@ -414,14 +414,13 @@ Args: {
     async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config):
         """Test that structured query results are properly formatted for agent consumption"""
         # Arrange
-        await agent_processor.on_tools_config(structured_query_tool_config, "v1")
+        await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")

         request = AgentRequest(
             question="Get customer information and format it nicely.",
             state="",
             group=None,
             history=[],
-            user="test_user"
         )

         msg = MagicMock()

@@ -482,6 +481,7 @@ Args: {
         flow = MagicMock()
         flow.side_effect = flow_context
+        flow.workspace = "default"

         # Act
         await agent_processor.on_request(msg, consumer, flow)
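The recurring change in this file is that tools config now arrives keyed by workspace, and requests are routed through `flow.workspace`. The following is a rough, self-contained sketch of that registry shape; `AgentProcessor`, `on_tools_config`, and `on_request` here are stand-ins mimicking the call signatures seen in the diff, not the real TrustGraph implementation.

```python
import asyncio
from unittest.mock import MagicMock

class AgentProcessor:
    """Sketch: one agent (with its own tool set) per workspace."""

    def __init__(self):
        self.agents = {}  # workspace name -> agent object

    async def on_tools_config(self, workspace, config, version):
        # Config updates target a single workspace's agent.
        agent = self.agents.setdefault(workspace, MagicMock())
        agent.tools = dict(config)

    async def on_request(self, msg, consumer, flow):
        # Workspace comes from the flow object, not the message payload.
        return self.agents[flow.workspace]

async def demo():
    proc = AgentProcessor()
    await proc.on_tools_config("default", {"structured-query": object()}, "v1")
    flow = MagicMock()
    flow.workspace = "default"
    agent = await proc.on_request(MagicMock(), None, flow)
    return "structured-query" in agent.tools

ok = asyncio.run(demo())
assert ok
```

This is why the tests above set `flow.workspace = "default"` on the mocked flow and read `agent_processor.agents["default"].tools` instead of a single `agent_processor.agent`.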

View file

@@ -40,14 +40,13 @@ class TestEndToEndConfigurationFlow:
         # Create a mock message to trigger TrustGraph creation
         mock_message = MagicMock()
-        mock_message.metadata.user = 'test_user'
         mock_message.metadata.collection = 'test_collection'
         mock_message.triples = []

         # Mock collection_exists to return True
         with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
             # This should create TrustGraph with environment config
-            await processor.store_triples(mock_message)
+            await processor.store_triples('test_user', mock_message)

         # Verify Cluster was created with correct hosts
         mock_cluster.assert_called_once()

@@ -144,13 +143,12 @@ class TestConfigurationPriorityEndToEnd:
         # Trigger TrustGraph creation
         mock_message = MagicMock()
-        mock_message.metadata.user = 'test_user'
         mock_message.metadata.collection = 'test_collection'
         mock_message.triples = []

         # Mock collection_exists to return True
         with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
-            await processor.store_triples(mock_message)
+            await processor.store_triples('test_user', mock_message)

         # Should use CLI parameters, not environment
         mock_cluster.assert_called_once()

@@ -201,7 +199,6 @@ class TestConfigurationPriorityEndToEnd:
         # Mock query to trigger TrustGraph creation
         mock_query = MagicMock()
-        mock_query.user = 'default_user'
         mock_query.collection = 'default_collection'
         mock_query.s = None
         mock_query.p = None

@@ -213,7 +210,7 @@ class TestConfigurationPriorityEndToEnd:
         mock_tg_instance.get_all.return_value = []
         processor.tg = mock_tg_instance

-        await processor.query_triples(mock_query)
+        await processor.query_triples('default_user', mock_query)

         # Should use defaults
         mock_cluster.assert_called_once()

@@ -244,13 +241,12 @@ class TestNoBackwardCompatibilityEndToEnd:
         # Trigger TrustGraph creation
         mock_message = MagicMock()
-        mock_message.metadata.user = 'legacy_user'
         mock_message.metadata.collection = 'legacy_collection'
         mock_message.triples = []

         # Mock collection_exists to return True
         with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
-            await processor.store_triples(mock_message)
+            await processor.store_triples('legacy_user', mock_message)

         # Should use defaults since old parameters are not recognized
         mock_cluster.assert_called_once()

@@ -302,13 +298,12 @@ class TestNoBackwardCompatibilityEndToEnd:
         # Trigger TrustGraph creation
         mock_message = MagicMock()
-        mock_message.metadata.user = 'precedence_user'
         mock_message.metadata.collection = 'precedence_collection'
         mock_message.triples = []

         # Mock collection_exists to return True
         with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
-            await processor.store_triples(mock_message)
+            await processor.store_triples('precedence_user', mock_message)

         # Should use new parameters, not old ones
         mock_cluster.assert_called_once()

@@ -354,13 +349,12 @@ class TestMultipleHostsHandling:
         # Trigger TrustGraph creation
         mock_message = MagicMock()
-        mock_message.metadata.user = 'single_user'
         mock_message.metadata.collection = 'single_collection'
         mock_message.triples = []

         # Mock collection_exists to return True
         with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
-            await processor.store_triples(mock_message)
+            await processor.store_triples('single_user', mock_message)

         # Single host should be converted to list
         mock_cluster.assert_called_once()

View file

@@ -115,7 +115,7 @@ class TestCassandraIntegration:
         # Create test message
         storage_message = Triples(
-            metadata=Metadata(user="testuser", collection="testcol"),
+            metadata=Metadata(collection="testcol"),
             triples=[
                 Triple(
                     s=Term(type=IRI, iri="http://example.org/person1"),

@@ -178,7 +178,7 @@ class TestCassandraIntegration:
         # Store test data for querying
         query_test_message = Triples(
-            metadata=Metadata(user="testuser", collection="testcol"),
+            metadata=Metadata(collection="testcol"),
             triples=[
                 Triple(
                     s=Term(type=IRI, iri="http://example.org/alice"),

@@ -212,7 +212,6 @@ class TestCassandraIntegration:
             p=None,  # None for wildcard
             o=None,  # None for wildcard
             limit=10,
-            user="testuser",
             collection="testcol"
         )
         s_results = await query_processor.query_triples(s_query)

@@ -232,7 +231,6 @@ class TestCassandraIntegration:
             p=Term(type=IRI, iri="http://example.org/knows"),
             o=None,  # None for wildcard
             limit=10,
-            user="testuser",
             collection="testcol"
         )
         p_results = await query_processor.query_triples(p_query)

@@ -259,7 +257,7 @@ class TestCassandraIntegration:
         # Create multiple coroutines for concurrent storage
         async def store_person_data(person_id, name, age, department):
             message = Triples(
-                metadata=Metadata(user="concurrent_test", collection="people"),
+                metadata=Metadata(collection="people"),
                 triples=[
                     Triple(
                         s=Term(type=IRI, iri=f"http://example.org/{person_id}"),

@@ -329,7 +327,7 @@ class TestCassandraIntegration:
         # Create a knowledge graph about a company
         company_graph = Triples(
-            metadata=Metadata(user="integration_test", collection="company"),
+            metadata=Metadata(collection="company"),
             triples=[
                 # People and their types
                 Triple(

View file

@@ -99,7 +99,6 @@ class TestDocumentRagIntegration:
         # Act
         result = await document_rag.query(
             query=query,
-            user=user,
             collection=collection,
             doc_limit=doc_limit
         )

@@ -110,7 +109,6 @@ class TestDocumentRagIntegration:
         mock_doc_embeddings_client.query.assert_called_once_with(
             vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
             limit=doc_limit,
-            user=user,
             collection=collection
         )

@@ -278,14 +276,12 @@ class TestDocumentRagIntegration:
         # Act
         await document_rag.query(
             f"query from {user} in {collection}",
-            user=user,
             collection=collection
         )

         # Assert
         mock_doc_embeddings_client.query.assert_called_once()
         call_args = mock_doc_embeddings_client.query.call_args
-        assert call_args.kwargs['user'] == user
         assert call_args.kwargs['collection'] == collection

     @pytest.mark.asyncio

@@ -353,6 +349,5 @@ class TestDocumentRagIntegration:
         # Assert
         mock_doc_embeddings_client.query.assert_called_once()
         call_args = mock_doc_embeddings_client.query.call_args
-        assert call_args.kwargs['user'] == "trustgraph"
         assert call_args.kwargs['collection'] == "default"
         assert call_args.kwargs['limit'] == 20

View file

@@ -107,7 +107,6 @@ class TestDocumentRagStreaming:
         # Act
         result = await document_rag_streaming.query(
             query=query,
-            user="test_user",
             collection="test_collection",
             doc_limit=10,
             streaming=True,

@@ -141,7 +140,6 @@ class TestDocumentRagStreaming:
         # Act - Non-streaming
         non_streaming_result = await document_rag_streaming.query(
             query=query,
-            user=user,
             collection=collection,
             doc_limit=doc_limit,
             streaming=False

@@ -155,7 +153,6 @@ class TestDocumentRagStreaming:
         streaming_result = await document_rag_streaming.query(
             query=query,
-            user=user,
             collection=collection,
             doc_limit=doc_limit,
             streaming=True,

@@ -178,7 +175,6 @@ class TestDocumentRagStreaming:
         # Act
         result = await document_rag_streaming.query(
             query="test query",
-            user="test_user",
             collection="test_collection",
             doc_limit=5,
             streaming=True,

@@ -200,7 +196,6 @@ class TestDocumentRagStreaming:
         # Arrange & Act
         result = await document_rag_streaming.query(
             query="test query",
-            user="test_user",
             collection="test_collection",
             doc_limit=5,
             streaming=True,

@@ -223,7 +218,6 @@ class TestDocumentRagStreaming:
         # Act
         result = await document_rag_streaming.query(
             query="unknown topic",
-            user="test_user",
             collection="test_collection",
             doc_limit=10,
             streaming=True,

@@ -247,7 +241,6 @@ class TestDocumentRagStreaming:
         with pytest.raises(Exception) as exc_info:
             await document_rag_streaming.query(
                 query="test query",
-                user="test_user",
                 collection="test_collection",
                 doc_limit=5,
                 streaming=True,

@@ -272,7 +265,6 @@ class TestDocumentRagStreaming:
         # Act
         result = await document_rag_streaming.query(
             query="test query",
-            user="test_user",
             collection="test_collection",
             doc_limit=limit,
             streaming=True,

@@ -300,7 +292,6 @@ class TestDocumentRagStreaming:
         # Act
         await document_rag_streaming.query(
             query="test query",
-            user=user,
             collection=collection,
             doc_limit=10,
             streaming=True,

@@ -309,5 +300,4 @@ class TestDocumentRagStreaming:
         # Assert - Verify user/collection were passed to document embeddings client
         call_args = mock_doc_embeddings_client.query.call_args
-        assert call_args.kwargs['user'] == user
         assert call_args.kwargs['collection'] == collection

View file

@@ -146,7 +146,6 @@ class TestGraphRagIntegration:
         # Act
         response = await graph_rag.query(
             query=query,
-            user=user,
             collection=collection,
             entity_limit=entity_limit,
             triple_limit=triple_limit,

@@ -163,7 +162,6 @@ class TestGraphRagIntegration:
         call_args = mock_graph_embeddings_client.query.call_args
         assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
         assert call_args.kwargs['limit'] == entity_limit
-        assert call_args.kwargs['user'] == user
         assert call_args.kwargs['collection'] == collection

         # 3. Should query triples to build knowledge subgraph

@@ -204,7 +202,6 @@ class TestGraphRagIntegration:
         # Act
         await graph_rag.query(
             query=query,
-            user="test_user",
             collection="test_collection",
             entity_limit=config["entity_limit"],
             triple_limit=config["triple_limit"]

@@ -224,7 +221,6 @@ class TestGraphRagIntegration:
         with pytest.raises(Exception) as exc_info:
             await graph_rag.query(
                 query="test query",
-                user="test_user",
                 collection="test_collection"
             )

@@ -247,7 +243,6 @@ class TestGraphRagIntegration:
         # Act
         response = await graph_rag.query(
             query="unknown topic",
-            user="test_user",
             collection="test_collection",
             explain_callback=collect_provenance
         )

@@ -267,7 +262,6 @@ class TestGraphRagIntegration:
         # First query
         await graph_rag.query(
             query=query,
-            user="test_user",
             collection="test_collection"
         )

@@ -277,7 +271,6 @@ class TestGraphRagIntegration:
         # Second identical query
         await graph_rag.query(
             query=query,
-            user="test_user",
             collection="test_collection"
         )

@@ -289,26 +282,27 @@ class TestGraphRagIntegration:
         assert second_call_count >= 0  # Should complete without errors

     @pytest.mark.asyncio
-    async def test_graph_rag_multi_user_isolation(self, graph_rag, mock_graph_embeddings_client):
-        """Test that different users/collections are properly isolated"""
+    async def test_graph_rag_multi_collection_isolation(self, graph_rag, mock_graph_embeddings_client):
+        """Test that different collections propagate through to the embeddings query.
+
+        Workspace isolation is enforced by flow.workspace at the service
+        boundary, not by parameters on GraphRag.query, so this test
+        verifies collection routing only.
+        """
         # Arrange
         query = "test query"
-        user1, collection1 = "user1", "collection1"
-        user2, collection2 = "user2", "collection2"
+        collection1 = "collection1"
+        collection2 = "collection2"

         # Act
-        await graph_rag.query(query=query, user=user1, collection=collection1)
-        await graph_rag.query(query=query, user=user2, collection=collection2)
+        await graph_rag.query(query=query, collection=collection1)
+        await graph_rag.query(query=query, collection=collection2)

-        # Assert - Both users should have separate queries
+        # Assert - Each call propagated its collection
         assert mock_graph_embeddings_client.query.call_count == 2

-        # Verify first call
         first_call = mock_graph_embeddings_client.query.call_args_list[0]
-        assert first_call.kwargs['user'] == user1
         assert first_call.kwargs['collection'] == collection1

-        # Verify second call
         second_call = mock_graph_embeddings_client.query.call_args_list[1]
-        assert second_call.kwargs['user'] == user2
         assert second_call.kwargs['collection'] == collection2
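The collection-routing assertion style used in this file can be reproduced with only the standard library. A minimal sketch, assuming a query client whose call kwargs carry `collection` (the client and its `query` signature are mocked, not the real graph-embeddings client):

```python
import asyncio
from unittest.mock import AsyncMock

async def run_queries(client):
    # Two queries against different collections; after the refactor there
    # is no `user` kwarg, only `collection`.
    await client.query(vector=[[0.1]], limit=5, collection="collection1")
    await client.query(vector=[[0.1]], limit=5, collection="collection2")

client = AsyncMock()
asyncio.run(run_queries(client))

# Each call recorded its own collection, in order.
assert client.query.call_count == 2
first, second = client.query.call_args_list
assert first.kwargs["collection"] == "collection1"
assert second.kwargs["collection"] == "collection2"
```

`call_args_list` preserves call order, which is what lets the test distinguish the two collections without any shared state between calls.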

View file

@@ -116,7 +116,6 @@ class TestGraphRagStreaming:
         # Act - query() returns response, provenance via callback
         response = await graph_rag_streaming.query(
             query=query,
-            user="test_user",
             collection="test_collection",
             streaming=True,
             chunk_callback=collector.collect,

@@ -154,7 +153,6 @@ class TestGraphRagStreaming:
         # Act - Non-streaming
         non_streaming_response = await graph_rag_streaming.query(
             query=query,
-            user=user,
             collection=collection,
             streaming=False
         )

@@ -167,7 +165,6 @@ class TestGraphRagStreaming:
         streaming_response = await graph_rag_streaming.query(
             query=query,
-            user=user,
             collection=collection,
             streaming=True,
             chunk_callback=collect

@@ -189,7 +186,6 @@ class TestGraphRagStreaming:
         # Act
         response = await graph_rag_streaming.query(
             query="test query",
-            user="test_user",
             collection="test_collection",
             streaming=True,
             chunk_callback=callback

@@ -209,7 +205,6 @@ class TestGraphRagStreaming:
         # Arrange & Act
         response = await graph_rag_streaming.query(
             query="test query",
-            user="test_user",
             collection="test_collection",
             streaming=True,
             chunk_callback=None  # No callback provided

@@ -231,7 +226,6 @@ class TestGraphRagStreaming:
         # Act
         response = await graph_rag_streaming.query(
             query="unknown topic",
-            user="test_user",
             collection="test_collection",
             streaming=True,
             chunk_callback=callback

@@ -253,7 +247,6 @@ class TestGraphRagStreaming:
         with pytest.raises(Exception) as exc_info:
             await graph_rag_streaming.query(
                 query="test query",
-                user="test_user",
                 collection="test_collection",
                 streaming=True,
                 chunk_callback=callback

@@ -273,7 +266,6 @@ class TestGraphRagStreaming:
         # Act
         await graph_rag_streaming.query(
             query="test query",
-            user="test_user",
             collection="test_collection",
             entity_limit=entity_limit,
             triple_limit=triple_limit,

View file

@@ -171,7 +171,6 @@ async def test_export_no_message_loss_integration(mock_backend):
         triples_obj = Triples(
             metadata=Metadata(
                 id=f"export-msg-{i}",
-                user=msg_data["metadata"]["user"],
                 collection=msg_data["metadata"]["collection"],
             ),
             triples=to_subgraph(msg_data["triples"]),

View file

@@ -97,7 +97,6 @@ class TestKnowledgeGraphPipelineIntegration:
         return Chunk(
             metadata=Metadata(
                 id="doc-123",
-                user="test_user",
                 collection="test_collection",
             ),
             chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns."
@@ -247,7 +246,6 @@ class TestKnowledgeGraphPipelineIntegration:
         # Arrange
         metadata = Metadata(
             id="test-doc",
-            user="test_user",
             collection="test_collection",
         )

@@ -305,7 +303,6 @@ class TestKnowledgeGraphPipelineIntegration:
         # Arrange
         metadata = Metadata(
             id="test-doc",
-            user="test_user",
             collection="test_collection",
         )

@@ -375,7 +372,6 @@ class TestKnowledgeGraphPipelineIntegration:
         sample_triples = Triples(
             metadata=Metadata(
                 id="test-doc",
-                user="test_user",
                 collection="test_collection",
             ),
             triples=[
@@ -390,11 +386,14 @@ class TestKnowledgeGraphPipelineIntegration:
         mock_msg = MagicMock()
         mock_msg.value.return_value = sample_triples

+        mock_flow = MagicMock()
+        mock_flow.workspace = "test_workspace"
+
         # Act
-        await processor.on_triples(mock_msg, None, None)
+        await processor.on_triples(mock_msg, None, mock_flow)

         # Assert
-        mock_cassandra_store.add_triples.assert_called_once_with(sample_triples)
+        mock_cassandra_store.add_triples.assert_called_once_with("test_workspace", sample_triples)

     @pytest.mark.asyncio
     async def test_knowledge_store_graph_embeddings_storage(self, mock_cassandra_store):
@@ -407,7 +406,6 @@ class TestKnowledgeGraphPipelineIntegration:
         sample_embeddings = GraphEmbeddings(
             metadata=Metadata(
                 id="test-doc",
-                user="test_user",
                 collection="test_collection",
             ),
             entities=[
@@ -421,11 +419,14 @@ class TestKnowledgeGraphPipelineIntegration:
         mock_msg = MagicMock()
         mock_msg.value.return_value = sample_embeddings

+        mock_flow = MagicMock()
+        mock_flow.workspace = "test_workspace"
+
         # Act
-        await processor.on_graph_embeddings(mock_msg, None, None)
+        await processor.on_graph_embeddings(mock_msg, None, mock_flow)

         # Assert
-        mock_cassandra_store.add_graph_embeddings.assert_called_once_with(sample_embeddings)
+        mock_cassandra_store.add_graph_embeddings.assert_called_once_with("test_workspace", sample_embeddings)

     @pytest.mark.asyncio
     async def test_end_to_end_pipeline_coordination(self, definitions_processor, relationships_processor,
@@ -553,7 +554,7 @@ class TestKnowledgeGraphPipelineIntegration:
         )

         sample_chunk = Chunk(
-            metadata=Metadata(id="test", user="user", collection="collection"),
+            metadata=Metadata(id="test", collection="collection"),
             chunk=b"Test chunk"
         )

@@ -580,7 +581,7 @@ class TestKnowledgeGraphPipelineIntegration:
         # Arrange
         large_chunk_batch = [
             Chunk(
-                metadata=Metadata(id=f"doc-{i}", user="user", collection="collection"),
+                metadata=Metadata(id=f"doc-{i}", collection="collection"),
                 chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8")
             )
             for i in range(100)  # Large batch
@@ -617,7 +618,6 @@ class TestKnowledgeGraphPipelineIntegration:
         # Arrange
         original_metadata = Metadata(
             id="test-doc-123",
-            user="test_user",
             collection="test_collection",
         )

@@ -646,9 +646,7 @@ class TestKnowledgeGraphPipelineIntegration:
         entity_contexts_call = entity_contexts_producer.send.call_args[0][0]

         assert triples_call.metadata.id == "test-doc-123"
-        assert triples_call.metadata.user == "test_user"
         assert triples_call.metadata.collection == "test_collection"
         assert entity_contexts_call.metadata.id == "test-doc-123"
-        assert entity_contexts_call.metadata.user == "test_user"
         assert entity_contexts_call.metadata.collection == "test_collection"


@@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration:
         )

         # Set up schemas
-        proc.schemas = sample_schemas
+        proc.schemas = {"default": dict(sample_schemas)}

         # Mock the client method
         proc.client = MagicMock()
@@ -94,6 +94,7 @@ class TestNLPQueryServiceIntegration:
         consumer = MagicMock()

         flow = MagicMock()
+        flow.workspace = "default"
         flow_response = AsyncMock()
         flow.return_value = flow_response
@@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration:
         consumer = MagicMock()

         flow = MagicMock()
+        flow.workspace = "default"
         flow_response = AsyncMock()
         flow.return_value = flow_response
@@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration:
         }

         # Act - Update configuration
-        await integration_processor.on_schema_config(new_schema_config, "v2")
+        await integration_processor.on_schema_config("default", new_schema_config, "v2")

         # Arrange - Test query using new schema
         request = QuestionToStructuredQueryRequest(
@@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration:
         consumer = MagicMock()

         flow = MagicMock()
+        flow.workspace = "default"
         flow_response = AsyncMock()
         flow.return_value = flow_response
@@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration:
         await integration_processor.on_message(msg, consumer, flow)

         # Assert
-        assert "inventory" in integration_processor.schemas
+        assert "inventory" in integration_processor.schemas["default"]
         response_call = flow_response.send.call_args
         response = response_call[0][0]
         assert response.detected_schemas == ["inventory"]
@@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration:
         consumer = MagicMock()

         flow = MagicMock()
+        flow.workspace = "default"
         flow_response = AsyncMock()
         flow.return_value = flow_response
@@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration:
             graphql_generation_template="custom-graphql-generator"
         )

-        custom_processor.schemas = sample_schemas
+        custom_processor.schemas = {"default": dict(sample_schemas)}
         custom_processor.client = MagicMock()

         request = QuestionToStructuredQueryRequest(
@@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration:
         consumer = MagicMock()

         flow = MagicMock()
+        flow.workspace = "default"
         flow_response = AsyncMock()
         flow.return_value = flow_response
@@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration:
             ] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)]
         )

-        integration_processor.schemas.update(large_schema_set)
+        integration_processor.schemas["default"].update(large_schema_set)

         request = QuestionToStructuredQueryRequest(
             question="Show me data from table_05 and table_12",
@@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration:
         consumer = MagicMock()

         flow = MagicMock()
+        flow.workspace = "default"
         flow_response = AsyncMock()
         flow.return_value = flow_response
@@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration:
             msg.properties.return_value = {"id": f"concurrent-test-{i}"}

             flow = MagicMock()
+            flow.workspace = "default"
             flow_response = AsyncMock()
             flow.return_value = flow_response
@@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration:
         consumer = MagicMock()

         flow = MagicMock()
+        flow.workspace = "default"
         flow_response = AsyncMock()
         flow.return_value = flow_response


@@ -185,6 +185,7 @@ class TestObjectExtractionServiceIntegration:
             return AsyncMock()

         context.side_effect = context_router
+        context.workspace = "default"
         return context

     @pytest.mark.asyncio
@@ -197,20 +198,21 @@ class TestObjectExtractionServiceIntegration:
         processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)

         # Act
-        await processor.on_schema_config(integration_config, version=1)
+        await processor.on_schema_config("default", integration_config, version=1)

         # Assert
-        assert len(processor.schemas) == 2
-        assert "customer_records" in processor.schemas
-        assert "product_catalog" in processor.schemas
+        ws_schemas = processor.schemas["default"]
+        assert len(ws_schemas) == 2
+        assert "customer_records" in ws_schemas
+        assert "product_catalog" in ws_schemas

         # Verify customer schema
-        customer_schema = processor.schemas["customer_records"]
+        customer_schema = ws_schemas["customer_records"]
         assert customer_schema.name == "customer_records"
         assert len(customer_schema.fields) == 4

         # Verify product schema
-        product_schema = processor.schemas["product_catalog"]
+        product_schema = ws_schemas["product_catalog"]
         assert product_schema.name == "product_catalog"
         assert len(product_schema.fields) == 4

@@ -237,12 +239,11 @@ class TestObjectExtractionServiceIntegration:
         processor.convert_values_to_strings = convert_values_to_strings

         # Load configuration
-        await processor.on_schema_config(integration_config, version=1)
+        await processor.on_schema_config("default", integration_config, version=1)

         # Create realistic customer data chunk
         metadata = Metadata(
             id="customer-doc-001",
-            user="integration_test",
             collection="test_documents",
         )

@@ -304,12 +305,11 @@ class TestObjectExtractionServiceIntegration:
         processor.convert_values_to_strings = convert_values_to_strings

         # Load configuration
-        await processor.on_schema_config(integration_config, version=1)
+        await processor.on_schema_config("default", integration_config, version=1)

         # Create realistic product data chunk
         metadata = Metadata(
             id="product-doc-001",
-            user="integration_test",
             collection="test_documents",
         )

@@ -368,7 +368,7 @@ class TestObjectExtractionServiceIntegration:
         processor.convert_values_to_strings = convert_values_to_strings

         # Load configuration
-        await processor.on_schema_config(integration_config, version=1)
+        await processor.on_schema_config("default", integration_config, version=1)

         # Create multiple test chunks
         chunks_data = [
@@ -382,7 +382,6 @@ class TestObjectExtractionServiceIntegration:
         for chunk_id, text in chunks_data:
             metadata = Metadata(
                 id=chunk_id,
-                user="concurrent_test",
                 collection="test_collection",
             )
             chunk = Chunk(metadata=metadata, chunk=text.encode('utf-8'))
@@ -431,19 +430,21 @@ class TestObjectExtractionServiceIntegration:
                 "customer_records": integration_config["schema"]["customer_records"]
             }
         }
-        await processor.on_schema_config(initial_config, version=1)
+        await processor.on_schema_config("default", initial_config, version=1)

-        assert len(processor.schemas) == 1
-        assert "customer_records" in processor.schemas
-        assert "product_catalog" not in processor.schemas
+        ws_schemas = processor.schemas["default"]
+        assert len(ws_schemas) == 1
+        assert "customer_records" in ws_schemas
+        assert "product_catalog" not in ws_schemas

         # Act - Reload with full configuration
-        await processor.on_schema_config(integration_config, version=2)
+        await processor.on_schema_config("default", integration_config, version=2)

         # Assert
-        assert len(processor.schemas) == 2
-        assert "customer_records" in processor.schemas
-        assert "product_catalog" in processor.schemas
+        ws_schemas = processor.schemas["default"]
+        assert len(ws_schemas) == 2
+        assert "customer_records" in ws_schemas
+        assert "product_catalog" in ws_schemas

     @pytest.mark.asyncio
     async def test_error_resilience_integration(self, integration_config):
@@ -474,13 +475,14 @@ class TestObjectExtractionServiceIntegration:
             return AsyncMock()

         failing_flow.side_effect = failing_context_router
+        failing_flow.workspace = "default"
         processor.flow = failing_flow

         # Load configuration
-        await processor.on_schema_config(integration_config, version=1)
+        await processor.on_schema_config("default", integration_config, version=1)

         # Create test chunk
-        metadata = Metadata(id="error-test", user="test", collection="test")
+        metadata = Metadata(id="error-test", collection="test")
         chunk = Chunk(metadata=metadata, chunk=b"Some text that will fail to process")

         mock_msg = MagicMock()
@@ -510,12 +512,11 @@ class TestObjectExtractionServiceIntegration:
         processor.convert_values_to_strings = convert_values_to_strings

         # Load configuration
-        await processor.on_schema_config(integration_config, version=1)
+        await processor.on_schema_config("default", integration_config, version=1)

         # Create chunk with rich metadata
         original_metadata = Metadata(
             id="metadata-test-chunk",
-            user="test_user",
             collection="test_collection",
         )

@@ -544,6 +545,5 @@ class TestObjectExtractionServiceIntegration:
         assert extracted_obj is not None

         # Verify metadata propagation
-        assert extracted_obj.metadata.user == "test_user"
         assert extracted_obj.metadata.collection == "test_collection"
         assert "metadata-test-chunk" in extracted_obj.metadata.id  # Should include source reference


@@ -87,6 +87,7 @@ class TestPromptStreaming:
             return AsyncMock()

         context.side_effect = context_router
+        context.workspace = "default"
         return context

     @pytest.fixture
@@ -109,7 +110,7 @@ class TestPromptStreaming:
     def prompt_processor_streaming(self, mock_prompt_manager):
        """Create Prompt processor with streaming support"""
         processor = MagicMock()
-        processor.manager = mock_prompt_manager
+        processor.managers = {"default": mock_prompt_manager}
         processor.config_key = "prompt"

         # Bind the actual on_request method
@@ -248,6 +249,7 @@ class TestPromptStreaming:
             return AsyncMock()

         context.side_effect = context_router
+        context.workspace = "default"

         request = PromptRequest(
             id="test_prompt",
@@ -341,6 +343,7 @@ class TestPromptStreaming:
             return AsyncMock()

         context.side_effect = context_router
+        context.workspace = "default"

         request = PromptRequest(
             id="test_prompt",


@ -84,7 +84,6 @@ class TestGraphRagStreamingProtocol:
# Act # Act
await graph_rag.query( await graph_rag.query(
query="test query", query="test query",
user="test_user",
collection="test_collection", collection="test_collection",
streaming=True, streaming=True,
chunk_callback=callback chunk_callback=callback
@ -108,7 +107,6 @@ class TestGraphRagStreamingProtocol:
# Act # Act
await graph_rag.query( await graph_rag.query(
query="test query", query="test query",
user="test_user",
collection="test_collection", collection="test_collection",
streaming=True, streaming=True,
chunk_callback=collect chunk_callback=collect
@ -137,7 +135,6 @@ class TestGraphRagStreamingProtocol:
# Act # Act
await graph_rag.query( await graph_rag.query(
query="test query", query="test query",
user="test_user",
collection="test_collection", collection="test_collection",
streaming=True, streaming=True,
chunk_callback=collect chunk_callback=collect
@ -162,7 +159,6 @@ class TestGraphRagStreamingProtocol:
# Act # Act
await graph_rag.query( await graph_rag.query(
query="test query", query="test query",
user="test_user",
collection="test_collection", collection="test_collection",
streaming=True, streaming=True,
chunk_callback=collect chunk_callback=collect
@ -188,7 +184,6 @@ class TestGraphRagStreamingProtocol:
# Act # Act
await graph_rag.query( await graph_rag.query(
query="test query", query="test query",
user="test_user",
collection="test_collection", collection="test_collection",
streaming=True, streaming=True,
chunk_callback=collect chunk_callback=collect
@ -267,7 +262,6 @@ class TestDocumentRagStreamingProtocol:
# Act # Act
await document_rag.query( await document_rag.query(
query="test query", query="test query",
user="test_user",
collection="test_collection", collection="test_collection",
streaming=True, streaming=True,
chunk_callback=callback chunk_callback=callback
@ -290,7 +284,6 @@ class TestDocumentRagStreamingProtocol:
# Act # Act
await document_rag.query( await document_rag.query(
query="test query", query="test query",
user="test_user",
collection="test_collection", collection="test_collection",
streaming=True, streaming=True,
chunk_callback=collect chunk_callback=collect
@ -314,7 +307,6 @@ class TestDocumentRagStreamingProtocol:
# Act # Act
await document_rag.query( await document_rag.query(
query="test query", query="test query",
user="test_user",
collection="test_collection", collection="test_collection",
streaming=True, streaming=True,
chunk_callback=collect chunk_callback=collect


@@ -14,6 +14,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
 from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field


+class _MockFlowDefault:
+    """Mock Flow with default workspace for testing."""
+    workspace = "default"
+    name = "default"
+    id = "test-processor"
+
+
+mock_flow_default = _MockFlowDefault()
+
+
 @pytest.mark.integration
 class TestRowsCassandraIntegration:
     """Integration tests for Cassandra row storage with unified table"""
@@ -125,14 +136,13 @@ class TestRowsCassandraIntegration:
             }
         }

-        await processor.on_schema_config(config, version=1)
-        assert "customer_records" in processor.schemas
+        await processor.on_schema_config("default", config, version=1)
+        assert "customer_records" in processor.schemas["default"]

         # Step 2: Process an ExtractedObject
         test_obj = ExtractedObject(
             metadata=Metadata(
                 id="doc-001",
-                user="test_user",
                 collection="import_2024",
             ),
             schema_name="customer_records",
@@ -149,7 +159,7 @@ class TestRowsCassandraIntegration:
         msg = MagicMock()
         msg.value.return_value = test_obj

-        await processor.on_object(msg, None, None)
+        await processor.on_object(msg, None, mock_flow_default)

         # Verify Cassandra interactions
         assert mock_cluster.connect.called
@@ -158,7 +168,7 @@ class TestRowsCassandraIntegration:
         keyspace_calls = [call for call in mock_session.execute.call_args_list
                           if "CREATE KEYSPACE" in str(call)]
         assert len(keyspace_calls) == 1
-        assert "test_user" in str(keyspace_calls[0])
+        assert "default" in str(keyspace_calls[0])

         # Verify unified table creation (rows table, not per-schema table)
         table_calls = [call for call in mock_session.execute.call_args_list
@@ -209,12 +219,12 @@ class TestRowsCassandraIntegration:
             }
         }

-        await processor.on_schema_config(config, version=1)
-        assert len(processor.schemas) == 2
+        await processor.on_schema_config("default", config, version=1)
+        assert len(processor.schemas["default"]) == 2

         # Process objects for different schemas
         product_obj = ExtractedObject(
-            metadata=Metadata(id="p1", user="shop", collection="catalog"),
+            metadata=Metadata(id="p1", collection="catalog"),
             schema_name="products",
             values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
             confidence=0.9,
@@ -222,7 +232,7 @@ class TestRowsCassandraIntegration:
         )

         order_obj = ExtractedObject(
-            metadata=Metadata(id="o1", user="shop", collection="sales"),
+            metadata=Metadata(id="o1", collection="sales"),
             schema_name="orders",
             values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
             confidence=0.85,
@@ -233,7 +243,7 @@ class TestRowsCassandraIntegration:
         for obj in [product_obj, order_obj]:
             msg = MagicMock()
             msg.value.return_value = obj
-            await processor.on_object(msg, None, None)
+            await processor.on_object(msg, None, mock_flow_default)

         # All data goes into the same unified rows table
         table_calls = [call for call in mock_session.execute.call_args_list
@@ -256,7 +266,8 @@ class TestRowsCassandraIntegration:
         with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
             # Schema with multiple indexed fields
-            processor.schemas["indexed_data"] = RowSchema(
+            processor.schemas["default"] = {
+                "indexed_data": RowSchema(
                 name="indexed_data",
                 fields=[
                     Field(name="id", type="string", size=50, primary=True),
@@ -265,9 +276,10 @@ class TestRowsCassandraIntegration:
                     Field(name="description", type="string", size=200)  # Not indexed
                 ]
             )
+            }

         test_obj = ExtractedObject(
-            metadata=Metadata(id="t1", user="test", collection="test"),
+            metadata=Metadata(id="t1", collection="test"),
             schema_name="indexed_data",
             values=[{
                 "id": "123",
@@ -282,7 +294,7 @@ class TestRowsCassandraIntegration:
         msg = MagicMock()
         msg.value.return_value = test_obj

-        await processor.on_object(msg, None, None)
+        await processor.on_object(msg, None, mock_flow_default)

         # Should have 3 data inserts (one per indexed field: id, category, status)
         rows_insert_calls = [call for call in mock_session.execute.call_args_list
@@ -342,13 +354,12 @@ class TestRowsCassandraIntegration:
             }
         }

-        await processor.on_schema_config(config, version=1)
+        await processor.on_schema_config("default", config, version=1)

         # Process batch object with multiple values
         batch_obj = ExtractedObject(
             metadata=Metadata(
                 id="batch-001",
-                user="test_user",
                 collection="batch_import",
             ),
             schema_name="batch_customers",
@@ -376,7 +387,7 @@ class TestRowsCassandraIntegration:
         msg = MagicMock()
         msg.value.return_value = batch_obj

-        await processor.on_object(msg, None, None)
+        await processor.on_object(msg, None, mock_flow_default)

         # Verify unified table creation
         table_calls = [call for call in mock_session.execute.call_args_list
@@ -396,14 +407,16 @@ class TestRowsCassandraIntegration:
         processor, mock_cluster, mock_session = processor_with_mocks

         with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
-            processor.schemas["empty_test"] = RowSchema(
+            processor.schemas["default"] = {
+                "empty_test": RowSchema(
                 name="empty_test",
                 fields=[Field(name="id", type="string", size=50, primary=True)]
             )
+            }

             # Process empty batch object
             empty_obj = ExtractedObject(
-                metadata=Metadata(id="empty-1", user="test", collection="empty"),
+                metadata=Metadata(id="empty-1", collection="empty"),
                 schema_name="empty_test",
                 values=[],  # Empty batch
                 confidence=1.0,
@@ -413,7 +426,7 @@ class TestRowsCassandraIntegration:
             msg = MagicMock()
             msg.value.return_value = empty_obj

-            await processor.on_object(msg, None, None)
+            await processor.on_object(msg, None, mock_flow_default)

             # Should not create any data insert statements for empty batch
             # (partition registration may still happen)
@@ -428,7 +441,8 @@ class TestRowsCassandraIntegration:
         processor, mock_cluster, mock_session = processor_with_mocks

         with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
-            processor.schemas["map_test"] = RowSchema(
+            processor.schemas["default"] = {
+                "map_test": RowSchema(
                 name="map_test",
                 fields=[
                     Field(name="id", type="string", size=50, primary=True),
@@ -436,9 +450,10 @@ class TestRowsCassandraIntegration:
                     Field(name="count", type="integer", size=0)
                 ]
             )
+            }

         test_obj = ExtractedObject(
-            metadata=Metadata(id="t1", user="test", collection="test"),
+            metadata=Metadata(id="t1", collection="test"),
             schema_name="map_test",
             values=[{"id": "123", "name": "Test Item", "count": "42"}],
             confidence=0.9,
@@ -448,7 +463,7 @@ class TestRowsCassandraIntegration:
         msg = MagicMock()
         msg.value.return_value = test_obj

-        await processor.on_object(msg, None, None)
+        await processor.on_object(msg, None, mock_flow_default)

         # Verify insert uses map for data
         rows_insert_calls = [call for call in mock_session.execute.call_args_list
@@ -473,16 +488,18 @@ class TestRowsCassandraIntegration:
         processor, mock_cluster, mock_session = processor_with_mocks

         with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
-            processor.schemas["partition_test"] = RowSchema(
+            processor.schemas["default"] = {
+                "partition_test": RowSchema(
                 name="partition_test",
                 fields=[
                     Field(name="id", type="string", size=50, primary=True),
                     Field(name="category", type="string", size=50, indexed=True)
                 ]
             )
+            }

         test_obj = ExtractedObject(
-            metadata=Metadata(id="t1", user="test", collection="my_collection"),
+            metadata=Metadata(id="t1", collection="my_collection"),
             schema_name="partition_test",
             values=[{"id": "123", "category": "test"}],
             confidence=0.9,
@@ -492,7 +509,7 @@ class TestRowsCassandraIntegration:
         msg = MagicMock()
         msg.value.return_value = test_obj

-        await processor.on_object(msg, None, None)
+        await processor.on_object(msg, None, mock_flow_default)

         # Verify partition registration
         partition_inserts = [call for call in mock_session.execute.call_args_list
@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_schema_configuration_and_generation(self, processor, sample_schema_config): async def test_schema_configuration_and_generation(self, processor, sample_schema_config):
"""Test schema configuration loading and GraphQL schema generation""" """Test schema configuration loading and GraphQL schema generation"""
# Load schema configuration # Load schema configuration
await processor.on_schema_config(sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
# Verify schemas were loaded # Verify schemas were loaded
assert len(processor.schemas) == 2 assert len(processor.schemas) == 2
@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config): async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config):
"""Test Cassandra connection and dynamic table creation""" """Test Cassandra connection and dynamic table creation"""
# Load schema configuration # Load schema configuration
await processor.on_schema_config(sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
# Connect to Cassandra # Connect to Cassandra
processor.connect_cassandra() processor.connect_cassandra()
@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config): async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config):
"""Test inserting data and querying via GraphQL""" """Test inserting data and querying via GraphQL"""
# Load schema and connect # Load schema and connect
await processor.on_schema_config(sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra() processor.connect_cassandra()
# Setup test data # Setup test data
@ -292,7 +292,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_graphql_query_with_filters(self, processor, sample_schema_config): async def test_graphql_query_with_filters(self, processor, sample_schema_config):
"""Test GraphQL queries with filtering on indexed fields""" """Test GraphQL queries with filtering on indexed fields"""
# Setup (reuse previous setup) # Setup (reuse previous setup)
await processor.on_schema_config(sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra() processor.connect_cassandra()
keyspace = "test_user" keyspace = "test_user"
@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_graphql_error_handling(self, processor, sample_schema_config): async def test_graphql_error_handling(self, processor, sample_schema_config):
"""Test GraphQL error handling for invalid queries""" """Test GraphQL error handling for invalid queries"""
# Setup # Setup
await processor.on_schema_config(sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
# Test invalid field query # Test invalid field query
invalid_query = ''' invalid_query = '''
@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_message_processing_integration(self, processor, sample_schema_config): async def test_message_processing_integration(self, processor, sample_schema_config):
"""Test full message processing workflow""" """Test full message processing workflow"""
# Setup # Setup
await processor.on_schema_config(sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra() processor.connect_cassandra()
# Create mock message # Create mock message
@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_concurrent_queries(self, processor, sample_schema_config): async def test_concurrent_queries(self, processor, sample_schema_config):
"""Test handling multiple concurrent GraphQL queries""" """Test handling multiple concurrent GraphQL queries"""
# Setup # Setup
await processor.on_schema_config(sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra() processor.connect_cassandra()
# Create multiple query tasks # Create multiple query tasks
@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration:
} }
} }
await processor.on_schema_config(initial_config, version=1) await processor.on_schema_config("default", initial_config, version=1)
assert len(processor.schemas) == 1 assert len(processor.schemas) == 1
assert "simple" in processor.schemas assert "simple" in processor.schemas
@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration:
} }
} }
await processor.on_schema_config(updated_config, version=2) await processor.on_schema_config("default", updated_config, version=2)
# Verify updated schemas # Verify updated schemas
assert len(processor.schemas) == 2 assert len(processor.schemas) == 2
@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_large_result_set_handling(self, processor, sample_schema_config): async def test_large_result_set_handling(self, processor, sample_schema_config):
"""Test handling of large query result sets""" """Test handling of large query result sets"""
# Setup # Setup
await processor.on_schema_config(sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra() processor.connect_cassandra()
keyspace = "large_test_user" keyspace = "large_test_user"
@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance:
} }
} }
await processor.on_schema_config(schema_config, version=1) await processor.on_schema_config("default", schema_config, version=1)
# Measure query execution time # Measure query execution time
start_time = time.time() start_time = time.time()
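Every call site in this file gains a leading workspace argument: `on_schema_config(config, version=...)` becomes `on_schema_config("default", config, version=...)`. A hypothetical minimal handler showing the per-workspace, version-guarded update that shape implies (names assumed from the tests, not the real implementation):

```python
import asyncio

class SchemaConfigHandler:
    """Keeps the latest schema config per workspace, ignoring stale versions."""

    def __init__(self):
        self.schemas = {}    # workspace -> config dict
        self.versions = {}   # workspace -> last applied version

    async def on_schema_config(self, workspace, config, version):
        # Replays at or below the last applied version are ignored
        if version <= self.versions.get(workspace, -1):
            return False
        self.schemas[workspace] = dict(config)
        self.versions[workspace] = version
        return True

h = SchemaConfigHandler()
asyncio.run(h.on_schema_config("default", {"customer": {}}, version=1))
print(h.versions["default"])  # 1
```

Versioning per workspace means a late-arriving update for one workspace cannot mask a newer config in another.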
@ -42,7 +42,6 @@ class TestStructuredQueryServiceIntegration:
# Arrange - Create realistic query request # Arrange - Create realistic query request
request = StructuredQueryRequest( request = StructuredQueryRequest(
question="Show me all customers from California who have made purchases over $500", question="Show me all customers from California who have made purchases over $500",
user="trustgraph",
collection="default" collection="default"
) )
@ -126,7 +125,6 @@ class TestStructuredQueryServiceIntegration:
assert "orders" in objects_call_args.query assert "orders" in objects_call_args.query
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
assert objects_call_args.variables["state"] == "California" assert objects_call_args.variables["state"] == "California"
assert objects_call_args.user == "trustgraph"
assert objects_call_args.collection == "default" assert objects_call_args.collection == "default"
# Verify response # Verify response
@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager # Setup mock agent manager
mock_agent_instance = AsyncMock() mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to call think and observe callbacks # Mock react to call think and observe callbacks
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None): async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@ -50,7 +53,6 @@ class TestAgentServiceNonStreaming:
msg = MagicMock() msg = MagicMock()
msg.value.return_value = AgentRequest( msg.value.return_value = AgentRequest(
question="What is 2 + 2?", question="What is 2 + 2?",
user="trustgraph",
streaming=False # Non-streaming mode streaming=False # Non-streaming mode
) )
msg.properties.return_value = {"id": "test-id"} msg.properties.return_value = {"id": "test-id"}
@ -58,6 +60,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock # Setup flow mock
consumer = MagicMock() consumer = MagicMock()
flow = MagicMock() flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock() mock_producer = AsyncMock()
@ -129,6 +132,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager # Setup mock agent manager
mock_agent_instance = AsyncMock() mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to return Final directly # Mock react to return Final directly
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None): async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@ -140,7 +146,6 @@ class TestAgentServiceNonStreaming:
msg = MagicMock() msg = MagicMock()
msg.value.return_value = AgentRequest( msg.value.return_value = AgentRequest(
question="What is 2 + 2?", question="What is 2 + 2?",
user="trustgraph",
streaming=False # Non-streaming mode streaming=False # Non-streaming mode
) )
msg.properties.return_value = {"id": "test-id"} msg.properties.return_value = {"id": "test-id"}
@ -148,6 +153,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock # Setup flow mock
consumer = MagicMock() consumer = MagicMock()
flow = MagicMock() flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock() mock_producer = AsyncMock()
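These fixtures drop the `user` field from `AgentRequest` and instead set `flow.workspace = "default"`, with the processor holding agents in `processor.agents["default"]`. A toy sketch of that dispatch — the fallback-to-default behaviour is an assumption for illustration, not confirmed by the diff:

```python
class Flow:
    def __init__(self, workspace):
        self.workspace = workspace

class AgentDispatcher:
    def __init__(self):
        self.agents = {}  # workspace -> agent instance

    def agent_for(self, flow):
        # Assumed fallback: unknown workspaces use the "default" agent
        return self.agents.get(flow.workspace, self.agents["default"])

d = AgentDispatcher()
d.agents["default"] = "default-agent"
d.agents["team-a"] = "team-a-agent"
print(d.agent_for(Flow("team-a")))   # team-a-agent
print(d.agent_for(Flow("unknown")))  # default-agent
```

Tenancy selection thus moves from a per-request `user` string to an attribute of the flow the message arrived on.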
@ -11,13 +11,12 @@ from trustgraph.schema import AgentRequest, AgentStep
from trustgraph.agent.orchestrator.aggregator import Aggregator from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(question="Test question", user="testuser", def _make_request(question="Test question",
collection="default", streaming=False, collection="default", streaming=False,
session_id="parent-session", task_type="research", session_id="parent-session", task_type="research",
framing="test framing", conversation_id="conv-1"): framing="test framing", conversation_id="conv-1"):
return AgentRequest( return AgentRequest(
question=question, question=question,
user=user,
collection=collection, collection=collection,
streaming=streaming, streaming=streaming,
session_id=session_id, session_id=session_id,
@ -127,7 +126,6 @@ class TestBuildSynthesisRequest:
req = agg.build_synthesis_request( req = agg.build_synthesis_request(
"corr-1", "corr-1",
original_question="Original question", original_question="Original question",
user="testuser",
collection="default", collection="default",
) )
@ -148,7 +146,7 @@ class TestBuildSynthesisRequest:
agg.record_completion("corr-1", "goal-b", "answer-b") agg.record_completion("corr-1", "goal-b", "answer-b")
req = agg.build_synthesis_request( req = agg.build_synthesis_request(
"corr-1", "question", "user", "default", "corr-1", "question", "default",
) )
# Last history step should be the synthesis step # Last history step should be the synthesis step
@ -168,7 +166,7 @@ class TestBuildSynthesisRequest:
agg.record_completion("corr-1", "goal-a", "answer-a") agg.record_completion("corr-1", "goal-a", "answer-a")
agg.build_synthesis_request( agg.build_synthesis_request(
"corr-1", "question", "user", "default", "corr-1", "question", "default",
) )
# Entry should be removed # Entry should be removed
@ -178,7 +176,7 @@ class TestBuildSynthesisRequest:
agg = Aggregator() agg = Aggregator()
with pytest.raises(RuntimeError, match="No results"): with pytest.raises(RuntimeError, match="No results"):
agg.build_synthesis_request( agg.build_synthesis_request(
"unknown", "question", "user", "default", "unknown", "question", "default",
) )
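Note how the positional call sites shrink in lock-step with the signature: `build_synthesis_request("corr-1", "question", "user", "default")` becomes `("corr-1", "question", "default")`. Removing a middle positional parameter silently shifts every later argument, which is why each caller had to change in the same commit. A toy illustration (this is a stand-in, not the real method):

```python
def build_synthesis_request(correlation_id, original_question, collection):
    # New three-argument signature: the old `user` parameter that sat
    # between question and collection is gone.
    return {
        "correlation_id": correlation_id,
        "question": original_question,
        "collection": collection,
    }

req = build_synthesis_request("corr-1", "question", "default")
print(req["collection"])  # default
```

Had any caller been missed, the old `"user"` string would have been bound to `collection` without raising an error.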
@ -15,7 +15,6 @@ from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(**kwargs): def _make_request(**kwargs):
defaults = dict( defaults = dict(
question="Test question", question="Test question",
user="testuser",
collection="default", collection="default",
) )
defaults.update(kwargs) defaults.update(kwargs)
@ -130,7 +129,6 @@ class TestAggregatorIntegration:
synth = agg.build_synthesis_request( synth = agg.build_synthesis_request(
"corr-1", "corr-1",
original_question="Original question", original_question="Original question",
user="testuser",
collection="default", collection="default",
) )
@ -160,7 +158,7 @@ class TestAggregatorIntegration:
agg.record_completion("corr-1", "goal", "answer") agg.record_completion("corr-1", "goal", "answer")
synth = agg.build_synthesis_request( synth = agg.build_synthesis_request(
"corr-1", "question", "user", "default", "corr-1", "question", "default",
) )
# correlation_id must be empty so it's not intercepted # correlation_id must be empty so it's not intercepted
@ -126,7 +126,6 @@ def make_base_request(**kwargs):
state="", state="",
group=[], group=[],
history=[], history=[],
user="testuser",
collection="default", collection="default",
streaming=False, streaming=False,
session_id="test-session-123", session_id="test-session-123",
@ -21,7 +21,6 @@ class MockProcessor:
def _make_request(**kwargs): def _make_request(**kwargs):
defaults = dict( defaults = dict(
question="Test question", question="Test question",
user="testuser",
collection="default", collection="default",
) )
defaults.update(kwargs) defaults.update(kwargs)
@ -167,39 +167,28 @@ class TestToolServiceRequest:
"""Test cases for tool service request format""" """Test cases for tool service request format"""
def test_request_format(self): def test_request_format(self):
"""Test that request is properly formatted with user, config, and arguments""" """Test that request is properly formatted with config and arguments"""
# Arrange
user = "alice"
config_values = {"style": "pun", "collection": "jokes"} config_values = {"style": "pun", "collection": "jokes"}
arguments = {"topic": "programming"} arguments = {"topic": "programming"}
# Act - simulate request building
request = { request = {
"user": user,
"config": json.dumps(config_values), "config": json.dumps(config_values),
"arguments": json.dumps(arguments) "arguments": json.dumps(arguments)
} }
# Assert
assert request["user"] == "alice"
assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"} assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"}
assert json.loads(request["arguments"]) == {"topic": "programming"} assert json.loads(request["arguments"]) == {"topic": "programming"}
def test_request_with_empty_config(self): def test_request_with_empty_config(self):
"""Test request when no config values are provided""" """Test request when no config values are provided"""
# Arrange
user = "bob"
config_values = {} config_values = {}
arguments = {"query": "test"} arguments = {"query": "test"}
# Act
request = { request = {
"user": user,
"config": json.dumps(config_values) if config_values else "{}", "config": json.dumps(config_values) if config_values else "{}",
"arguments": json.dumps(arguments) if arguments else "{}" "arguments": json.dumps(arguments) if arguments else "{}"
} }
# Assert
assert request["config"] == "{}" assert request["config"] == "{}"
assert json.loads(request["arguments"]) == {"query": "test"} assert json.loads(request["arguments"]) == {"query": "test"}
@ -386,18 +375,13 @@ class TestJokeServiceLogic:
assert map_topic_to_category("random topic") == "default" assert map_topic_to_category("random topic") == "default"
assert map_topic_to_category("") == "default" assert map_topic_to_category("") == "default"
def test_joke_response_personalization(self): def test_joke_response_format(self):
"""Test that joke responses include user personalization""" """Test that joke response is formatted as expected"""
# Arrange
user = "alice"
style = "pun" style = "pun"
joke = "Why do programmers prefer dark mode? Because light attracts bugs!" joke = "Why do programmers prefer dark mode? Because light attracts bugs!"
# Act response = f"Here's a {style} for you:\n\n{joke}"
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
# Assert
assert "Hey alice!" in response
assert "pun" in response assert "pun" in response
assert joke in response assert joke in response
@ -439,20 +423,14 @@ class TestDynamicToolServiceBase:
def test_request_parsing(self): def test_request_parsing(self):
"""Test parsing of incoming request""" """Test parsing of incoming request"""
# Arrange
request_data = { request_data = {
"user": "alice",
"config": '{"style": "pun"}', "config": '{"style": "pun"}',
"arguments": '{"topic": "programming"}' "arguments": '{"topic": "programming"}'
} }
# Act
user = request_data.get("user", "trustgraph")
config = json.loads(request_data["config"]) if request_data["config"] else {} config = json.loads(request_data["config"]) if request_data["config"] else {}
arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {} arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {}
# Assert
assert user == "alice"
assert config == {"style": "pun"} assert config == {"style": "pun"}
assert arguments == {"topic": "programming"} assert arguments == {"topic": "programming"}
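With the `user` field gone, a tool service request is just two JSON-encoded strings, `config` and `arguments`, each defaulting to `"{}"` when empty. A self-contained sketch of the build/parse round-trip these tests exercise (helper names are ours, for illustration):

```python
import json

def build_request(config=None, arguments=None):
    # Serialise dicts; empty or None values become the empty JSON object
    return {
        "config": json.dumps(config) if config else "{}",
        "arguments": json.dumps(arguments) if arguments else "{}",
    }

def parse_request(request):
    config = json.loads(request["config"]) if request["config"] else {}
    arguments = json.loads(request["arguments"]) if request["arguments"] else {}
    return config, arguments

req = build_request({"style": "pun"}, {"topic": "programming"})
print(parse_request(req))  # ({'style': 'pun'}, {'topic': 'programming'})
```

Guarding the `json.loads` calls keeps an entirely empty request (`""` fields) from raising instead of yielding empty dicts.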
@ -1,6 +1,6 @@
""" """
Tests for tool service lifecycle, invoke contract, streaming responses, Tests for tool service lifecycle, invoke contract, streaming responses,
multi-tenancy, and error propagation. and error propagation.
Tests the actual DynamicToolService, ToolService, and ToolServiceClient Tests the actual DynamicToolService, ToolService, and ToolServiceClient
classes rather than plain dicts. classes rather than plain dicts.
@ -31,7 +31,7 @@ class TestDynamicToolServiceInvokeContract:
svc = DynamicToolService.__new__(DynamicToolService) svc = DynamicToolService.__new__(DynamicToolService)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
await svc.invoke("user", {}, {}) await svc.invoke({}, {})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_on_request_calls_invoke_with_parsed_args(self): async def test_on_request_calls_invoke_with_parsed_args(self):
@ -44,8 +44,8 @@ class TestDynamicToolServiceInvokeContract:
calls = [] calls = []
async def tracking_invoke(user, config, arguments): async def tracking_invoke(config, arguments):
calls.append({"user": user, "config": config, "arguments": arguments}) calls.append({"config": config, "arguments": arguments})
return "ok" return "ok"
svc.invoke = tracking_invoke svc.invoke = tracking_invoke
@ -56,7 +56,6 @@ class TestDynamicToolServiceInvokeContract:
msg = MagicMock() msg = MagicMock()
msg.value.return_value = ToolServiceRequest( msg.value.return_value = ToolServiceRequest(
user="alice",
config='{"style": "pun"}', config='{"style": "pun"}',
arguments='{"topic": "cats"}', arguments='{"topic": "cats"}',
) )
@ -65,39 +64,9 @@ class TestDynamicToolServiceInvokeContract:
await svc.on_request(msg, MagicMock(), None) await svc.on_request(msg, MagicMock(), None)
assert len(calls) == 1 assert len(calls) == 1
assert calls[0]["user"] == "alice"
assert calls[0]["config"] == {"style": "pun"} assert calls[0]["config"] == {"style": "pun"}
assert calls[0]["arguments"] == {"topic": "cats"} assert calls[0]["arguments"] == {"topic": "cats"}
@pytest.mark.asyncio
async def test_on_request_empty_user_defaults_to_trustgraph(self):
"""Empty user field should default to 'trustgraph'."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
received_user = None
async def capture_invoke(user, config, arguments):
nonlocal received_user
received_user = user
return "ok"
svc.invoke = capture_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
msg.properties.return_value = {"id": "req-2"}
await svc.on_request(msg, MagicMock(), None)
assert received_user == "trustgraph"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_on_request_string_response_sent_directly(self): async def test_on_request_string_response_sent_directly(self):
"""String return from invoke → response field is the string.""" """String return from invoke → response field is the string."""
@ -107,7 +76,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc" svc.id = "test-svc"
svc.producer = AsyncMock() svc.producer = AsyncMock()
async def string_invoke(user, config, arguments): async def string_invoke(config, arguments):
return "hello world" return "hello world"
svc.invoke = string_invoke svc.invoke = string_invoke
@ -116,7 +85,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock() DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock() msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r1"} msg.properties.return_value = {"id": "r1"}
await svc.on_request(msg, MagicMock(), None) await svc.on_request(msg, MagicMock(), None)
@ -136,7 +105,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc" svc.id = "test-svc"
svc.producer = AsyncMock() svc.producer = AsyncMock()
async def dict_invoke(user, config, arguments): async def dict_invoke(config, arguments):
return {"result": 42} return {"result": 42}
svc.invoke = dict_invoke svc.invoke = dict_invoke
@ -145,7 +114,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock() DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock() msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r2"} msg.properties.return_value = {"id": "r2"}
await svc.on_request(msg, MagicMock(), None) await svc.on_request(msg, MagicMock(), None)
@ -162,13 +131,13 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc" svc.id = "test-svc"
svc.producer = AsyncMock() svc.producer = AsyncMock()
async def failing_invoke(user, config, arguments): async def failing_invoke(config, arguments):
raise ValueError("bad input") raise ValueError("bad input")
svc.invoke = failing_invoke svc.invoke = failing_invoke
msg = MagicMock() msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r3"} msg.properties.return_value = {"id": "r3"}
await svc.on_request(msg, MagicMock(), None) await svc.on_request(msg, MagicMock(), None)
@ -188,13 +157,13 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc" svc.id = "test-svc"
svc.producer = AsyncMock() svc.producer = AsyncMock()
async def rate_limited_invoke(user, config, arguments): async def rate_limited_invoke(config, arguments):
raise TooManyRequests("rate limited") raise TooManyRequests("rate limited")
svc.invoke = rate_limited_invoke svc.invoke = rate_limited_invoke
msg = MagicMock() msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r4"} msg.properties.return_value = {"id": "r4"}
with pytest.raises(TooManyRequests): with pytest.raises(TooManyRequests):
@ -209,7 +178,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc" svc.id = "test-svc"
svc.producer = AsyncMock() svc.producer = AsyncMock()
async def ok_invoke(user, config, arguments): async def ok_invoke(config, arguments):
return "ok" return "ok"
svc.invoke = ok_invoke svc.invoke = ok_invoke
@ -218,7 +187,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock() DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock() msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "unique-42"} msg.properties.return_value = {"id": "unique-42"}
await svc.on_request(msg, MagicMock(), None) await svc.on_request(msg, MagicMock(), None)
@ -241,7 +210,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService) svc = ToolService.__new__(ToolService)
svc.id = "test-tool" svc.id = "test-tool"
async def mock_invoke(name, params): async def mock_invoke(workspace, name, params):
return "tool result" return "tool result"
svc.invoke_tool = mock_invoke svc.invoke_tool = mock_invoke
@ -260,6 +229,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub} flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow" flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock() msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}') msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
@ -280,7 +250,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService) svc = ToolService.__new__(ToolService)
svc.id = "test-tool" svc.id = "test-tool"
async def mock_invoke(name, params): async def mock_invoke(workspace, name, params):
return {"data": [1, 2, 3]} return {"data": [1, 2, 3]}
svc.invoke_tool = mock_invoke svc.invoke_tool = mock_invoke
@ -298,6 +268,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub} flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow" flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock() msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@ -317,7 +288,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService) svc = ToolService.__new__(ToolService)
svc.id = "test-tool" svc.id = "test-tool"
async def failing_invoke(name, params): async def failing_invoke(workspace, name, params):
raise RuntimeError("tool broke") raise RuntimeError("tool broke")
svc.invoke_tool = failing_invoke svc.invoke_tool = failing_invoke
@ -330,6 +301,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub} flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow" flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock() msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@ -350,7 +322,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService) svc = ToolService.__new__(ToolService)
svc.id = "test-tool" svc.id = "test-tool"
async def rate_limited(name, params): async def rate_limited(workspace, name, params):
raise TooManyRequests("slow down") raise TooManyRequests("slow down")
svc.invoke_tool = rate_limited svc.invoke_tool = rate_limited
@ -362,6 +334,7 @@ class TestToolServiceOnRequest:
flow = MagicMock() flow = MagicMock()
flow.producer = {"response": AsyncMock()} flow.producer = {"response": AsyncMock()}
flow.name = "test-flow" flow.name = "test-flow"
flow.workspace = "default"
with pytest.raises(TooManyRequests): with pytest.raises(TooManyRequests):
await svc.on_request(msg, MagicMock(), flow) await svc.on_request(msg, MagicMock(), flow)
@ -376,7 +349,8 @@ class TestToolServiceOnRequest:
received = {} received = {}
async def capture_invoke(name, params): async def capture_invoke(workspace, name, params):
received["workspace"] = workspace
received["name"] = name received["name"] = name
received["params"] = params received["params"] = params
return "ok" return "ok"
@ -390,6 +364,7 @@ class TestToolServiceOnRequest:
flow = lambda name: mock_pub flow = lambda name: mock_pub
flow.producer = {"response": mock_pub} flow.producer = {"response": mock_pub}
flow.name = "f" flow.name = "f"
flow.workspace = "default"
msg = MagicMock() msg = MagicMock()
msg.value.return_value = ToolRequest( msg.value.return_value = ToolRequest(
@ -421,7 +396,6 @@ class TestToolServiceClientCall:
)) ))
result = await client.call( result = await client.call(
user="alice",
config={"style": "pun"}, config={"style": "pun"},
arguments={"topic": "cats"}, arguments={"topic": "cats"},
) )
@ -430,7 +404,6 @@ class TestToolServiceClientCall:
req = client.request.call_args[0][0] req = client.request.call_args[0][0]
assert isinstance(req, ToolServiceRequest) assert isinstance(req, ToolServiceRequest)
assert req.user == "alice"
assert json.loads(req.config) == {"style": "pun"} assert json.loads(req.config) == {"style": "pun"}
assert json.loads(req.arguments) == {"topic": "cats"} assert json.loads(req.arguments) == {"topic": "cats"}
@ -446,7 +419,7 @@ class TestToolServiceClientCall:
)) ))
with pytest.raises(RuntimeError, match="service down"): with pytest.raises(RuntimeError, match="service down"):
await client.call(user="u", config={}, arguments={}) await client.call(config={}, arguments={})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_call_empty_config_sends_empty_json(self): async def test_call_empty_config_sends_empty_json(self):
@ -458,7 +431,7 @@ class TestToolServiceClientCall:
             error=None, response="ok",
         ))
-        await client.call(user="u", config=None, arguments=None)
+        await client.call(config=None, arguments=None)
         req = client.request.call_args[0][0]
         assert req.config == "{}"
@@ -474,7 +447,7 @@ class TestToolServiceClientCall:
             error=None, response="ok",
         ))
-        await client.call(user="u", config={}, arguments={}, timeout=30)
+        await client.call(config={}, arguments={}, timeout=30)
         _, kwargs = client.request.call_args
         assert kwargs["timeout"] == 30
@@ -509,7 +482,7 @@ class TestToolServiceClientStreaming:
             received.append(text)
         result = await client.call_streaming(
-            user="u", config={}, arguments={}, callback=callback,
+            config={}, arguments={}, callback=callback,
         )
         assert result == "chunk1chunk2"
@@ -534,7 +507,7 @@ class TestToolServiceClientStreaming:
         with pytest.raises(RuntimeError, match="stream failed"):
             await client.call_streaming(
-                user="u", config={}, arguments={},
+                config={}, arguments={},
                 callback=AsyncMock(),
             )
@@ -564,61 +537,9 @@ class TestToolServiceClientStreaming:
             received.append(text)
         result = await client.call_streaming(
-            user="u", config={}, arguments={}, callback=callback,
+            config={}, arguments={}, callback=callback,
         )
         # Empty response is falsy, so callback shouldn't be called for it
         assert result == "data"
         assert received == ["data"]
-
-# ---------------------------------------------------------------------------
-# Multi-tenancy
-# ---------------------------------------------------------------------------
-
-class TestMultiTenancy:
-
-    @pytest.mark.asyncio
-    async def test_user_propagated_to_invoke(self):
-        """User from request should reach the invoke method."""
-        from trustgraph.base.dynamic_tool_service import DynamicToolService
-        svc = DynamicToolService.__new__(DynamicToolService)
-        svc.id = "test"
-        svc.producer = AsyncMock()
-        users_seen = []
-        async def tracking(user, config, arguments):
-            users_seen.append(user)
-            return "ok"
-        svc.invoke = tracking
-        if not hasattr(DynamicToolService, "tool_service_metric"):
-            DynamicToolService.tool_service_metric = MagicMock()
-        for u in ["tenant-a", "tenant-b", "tenant-c"]:
-            msg = MagicMock()
-            msg.value.return_value = ToolServiceRequest(
-                user=u, config="{}", arguments="{}",
-            )
-            msg.properties.return_value = {"id": f"req-{u}"}
-            await svc.on_request(msg, MagicMock(), None)
-        assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
-
-    @pytest.mark.asyncio
-    async def test_client_sends_user_in_request(self):
-        """ToolServiceClient.call should include user in request."""
-        from trustgraph.base.tool_service_client import ToolServiceClient
-        client = ToolServiceClient.__new__(ToolServiceClient)
-        client.request = AsyncMock(return_value=ToolServiceResponse(
-            error=None, response="ok",
-        ))
-        await client.call(user="isolated-tenant", config={}, arguments={})
-        req = client.request.call_args[0][0]
-        assert req.user == "isolated-tenant"

View file

@ -1,17 +1,14 @@
""" """
Tests for AsyncProcessor config notify pattern: Tests for AsyncProcessor config notify pattern:
- register_config_handler with types filtering - register_config_handler with types filtering
- on_config_notify version comparison and type matching - on_config_notify version comparison, type/workspace matching
- fetch_config with short-lived client - fetch_and_apply_config retry logic over per-workspace fetches
- fetch_and_apply_config retry logic
""" """
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch, Mock from unittest.mock import AsyncMock, MagicMock, patch, Mock
from trustgraph.schema import Term, IRI, LITERAL
# Patch heavy dependencies before importing AsyncProcessor
@pytest.fixture @pytest.fixture
def processor(): def processor():
"""Create an AsyncProcessor with mocked dependencies.""" """Create an AsyncProcessor with mocked dependencies."""
@ -68,6 +65,13 @@ class TestRegisterConfigHandler:
assert len(processor.config_handlers) == 2 assert len(processor.config_handlers) == 2
def _notify_msg(version, changes):
"""Build a Mock config-notify message with given version and changes dict."""
msg = Mock()
msg.value.return_value = Mock(version=version, changes=changes)
return msg
class TestOnConfigNotify: class TestOnConfigNotify:
@pytest.mark.asyncio @pytest.mark.asyncio
@ -77,9 +81,7 @@ class TestOnConfigNotify:
handler = AsyncMock() handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"]) processor.register_config_handler(handler, types=["prompt"])
msg = Mock() msg = _notify_msg(3, {"prompt": ["default"]})
msg.value.return_value = Mock(version=3, types=["prompt"])
await processor.on_config_notify(msg, None, None) await processor.on_config_notify(msg, None, None)
handler.assert_not_called() handler.assert_not_called()
@ -91,9 +93,7 @@ class TestOnConfigNotify:
handler = AsyncMock() handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"]) processor.register_config_handler(handler, types=["prompt"])
msg = Mock() msg = _notify_msg(5, {"prompt": ["default"]})
msg.value.return_value = Mock(version=5, types=["prompt"])
await processor.on_config_notify(msg, None, None) await processor.on_config_notify(msg, None, None)
handler.assert_not_called() handler.assert_not_called()
@ -105,9 +105,7 @@ class TestOnConfigNotify:
handler = AsyncMock() handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"]) processor.register_config_handler(handler, types=["prompt"])
msg = Mock() msg = _notify_msg(2, {"schema": ["default"]})
msg.value.return_value = Mock(version=2, types=["schema"])
await processor.on_config_notify(msg, None, None) await processor.on_config_notify(msg, None, None)
handler.assert_not_called() handler.assert_not_called()
@ -121,40 +119,36 @@ class TestOnConfigNotify:
handler = AsyncMock() handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"]) processor.register_config_handler(handler, types=["prompt"])
# Mock fetch_config mock_client = AsyncMock()
mock_config = {"prompt": {"key": "value"}}
with patch.object( with patch.object(
processor, 'fetch_config', processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=(mock_config, 2) return_value={"key": "value"},
): ):
msg = Mock() msg = _notify_msg(2, {"prompt": ["default"]})
msg.value.return_value = Mock(version=2, types=["prompt"])
await processor.on_config_notify(msg, None, None) await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2) handler.assert_called_once_with(
"default", {"prompt": {"key": "value"}}, 2
)
assert processor.config_version == 2 assert processor.config_version == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handler_without_types_always_called(self, processor): async def test_handler_without_types_ignored_on_notify(self, processor):
"""Handlers registered without types never fire on notifications."""
processor.config_version = 1 processor.config_version = 1
handler = AsyncMock() handler = AsyncMock()
processor.register_config_handler(handler) # No types = all processor.register_config_handler(handler) # No types
mock_config = {"anything": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 2)
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["whatever"])
msg = _notify_msg(2, {"whatever": ["default"]})
await processor.on_config_notify(msg, None, None) await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2) handler.assert_not_called()
# Version still advances past the notify
assert processor.config_version == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mixed_handlers_type_filtering(self, processor): async def test_mixed_handlers_type_filtering(self, processor):
@ -168,156 +162,149 @@ class TestOnConfigNotify:
processor.register_config_handler(schema_handler, types=["schema"]) processor.register_config_handler(schema_handler, types=["schema"])
processor.register_config_handler(all_handler) processor.register_config_handler(all_handler)
mock_config = {"prompt": {}} mock_client = AsyncMock()
with patch.object( with patch.object(
processor, 'fetch_config', processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=(mock_config, 2) return_value={},
): ):
msg = Mock() msg = _notify_msg(2, {"prompt": ["default"]})
msg.value.return_value = Mock(version=2, types=["prompt"])
await processor.on_config_notify(msg, None, None) await processor.on_config_notify(msg, None, None)
prompt_handler.assert_called_once() prompt_handler.assert_called_once_with(
"default", {"prompt": {}}, 2
)
schema_handler.assert_not_called() schema_handler.assert_not_called()
all_handler.assert_called_once() all_handler.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_types_invokes_all(self, processor): async def test_multi_workspace_notify_invokes_handler_per_ws(
"""Empty types list (startup signal) should invoke all handlers.""" self, processor
):
"""Notify affecting multiple workspaces invokes handler once per workspace."""
processor.config_version = 1 processor.config_version = 1
h1 = AsyncMock() handler = AsyncMock()
h2 = AsyncMock() processor.register_config_handler(handler, types=["prompt"])
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
mock_config = {} mock_client = AsyncMock()
with patch.object( with patch.object(
processor, 'fetch_config', processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=(mock_config, 2) return_value={},
): ):
msg = Mock() msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]})
msg.value.return_value = Mock(version=2, types=[])
await processor.on_config_notify(msg, None, None) await processor.on_config_notify(msg, None, None)
h1.assert_called_once() assert handler.call_count == 2
h2.assert_called_once() called_workspaces = {c.args[0] for c in handler.call_args_list}
assert called_workspaces == {"ws1", "ws2"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fetch_failure_handled(self, processor): async def test_fetch_failure_handled(self, processor):
processor.config_version = 1 processor.config_version = 1
handler = AsyncMock() handler = AsyncMock()
processor.register_config_handler(handler) processor.register_config_handler(handler, types=["prompt"])
mock_client = AsyncMock()
with patch.object( with patch.object(
processor, 'fetch_config', processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock, new_callable=AsyncMock,
side_effect=RuntimeError("Connection failed") side_effect=RuntimeError("Connection failed"),
): ):
msg = Mock() msg = _notify_msg(2, {"prompt": ["default"]})
msg.value.return_value = Mock(version=2, types=["prompt"])
# Should not raise # Should not raise
await processor.on_config_notify(msg, None, None) await processor.on_config_notify(msg, None, None)
handler.assert_not_called() handler.assert_not_called()
class TestFetchConfig:
@pytest.mark.asyncio
async def test_fetch_returns_config_and_version(self, processor):
mock_resp = Mock()
mock_resp.error = None
mock_resp.config = {"prompt": {"key": "val"}}
mock_resp.version = 42
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
config, version = await processor.fetch_config()
assert config == {"prompt": {"key": "val"}}
assert version == 42
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_raises_on_error_response(self, processor):
mock_resp = Mock()
mock_resp.error = Mock(message="not found")
mock_resp.config = {}
mock_resp.version = 0
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(RuntimeError, match="Config error"):
await processor.fetch_config()
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_stops_client_on_exception(self, processor):
mock_client = AsyncMock()
mock_client.request.side_effect = TimeoutError("timeout")
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(TimeoutError):
await processor.fetch_config()
mock_client.stop.assert_called_once()
class TestFetchAndApplyConfig: class TestFetchAndApplyConfig:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_applies_config_to_all_handlers(self, processor): async def test_applies_config_per_workspace(self, processor):
h1 = AsyncMock() """Startup fetch invokes handler once per workspace affected."""
h2 = AsyncMock() h = AsyncMock()
processor.register_config_handler(h1, types=["prompt"]) processor.register_config_handler(h, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {
"ws1": {"k": "v1"},
"ws2": {"k": "v2"},
}, 10
mock_config = {"prompt": {}, "schema": {}}
with patch.object( with patch.object(
processor, 'fetch_config', processor, '_create_config_client', return_value=mock_client
new_callable=AsyncMock, ), patch.object(
return_value=(mock_config, 10) processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
): ):
await processor.fetch_and_apply_config() await processor.fetch_and_apply_config()
# On startup, all handlers are invoked regardless of type assert h.call_count == 2
h1.assert_called_once_with(mock_config, 10) call_map = {c.args[0]: c.args[1] for c in h.call_args_list}
h2.assert_called_once_with(mock_config, 10) assert call_map["ws1"] == {"prompt": {"k": "v1"}}
assert call_map["ws2"] == {"prompt": {"k": "v2"}}
assert processor.config_version == 10 assert processor.config_version == 10
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_retries_on_failure(self, processor): async def test_handler_without_types_skipped_at_startup(self, processor):
call_count = 0 """Handlers registered without types fetch nothing at startup."""
mock_config = {"prompt": {}} typed = AsyncMock()
untyped = AsyncMock()
processor.register_config_handler(typed, types=["prompt"])
processor.register_config_handler(untyped)
async def mock_fetch(): mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {"default": {}}, 1
with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
):
await processor.fetch_and_apply_config()
typed.assert_called_once()
untyped.assert_not_called()
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
h = AsyncMock()
processor.register_config_handler(h, types=["prompt"])
call_count = 0
async def fake_fetch_all(client, config_type):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
if call_count < 3: if call_count < 3:
raise RuntimeError("not ready") raise RuntimeError("not ready")
return mock_config, 5 return {"default": {"k": "v"}}, 5
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \ mock_client = AsyncMock()
patch('asyncio.sleep', new_callable=AsyncMock): with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
), patch('asyncio.sleep', new_callable=AsyncMock):
await processor.fetch_and_apply_config() await processor.fetch_and_apply_config()
assert call_count == 3 assert call_count == 3
assert processor.config_version == 5 assert processor.config_version == 5
h.assert_called_once_with(
"default", {"prompt": {"k": "v"}}, 5
)

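Taken together, the hunks above encode the behavioural change these tests pin down: a config-notify message now carries a `changes` dict mapping config type to affected workspaces, typed handlers are invoked once per affected workspace with a `(workspace, config, version)` signature, and handlers registered without types no longer fire on notifications. A minimal sketch of that dispatch loop, with invented names independent of the real `AsyncProcessor`:

```python
import asyncio

# Hypothetical sketch (all names invented): dispatch a config-notify whose
# "changes" dict maps config type -> affected workspaces, invoking a
# type-matched handler once per workspace, as the tests above assert.

async def dispatch(changes, handlers, fetch, version):
    for config_type, workspaces in changes.items():
        handler = handlers.get(config_type)
        if handler is None:
            continue  # handlers without a registered type never fire
        for ws in workspaces:
            config = await fetch(config_type, ws)
            await handler(ws, {config_type: config}, version)

calls = []

async def on_prompt(workspace, config, version):
    calls.append((workspace, config, version))

async def fetch(config_type, workspace):
    # Stand-in for a per-workspace config fetch
    return {"key": "value"}

asyncio.run(dispatch({"prompt": ["ws1", "ws2"]}, {"prompt": on_prompt}, fetch, 2))
# calls now holds one entry per workspace, in notify order
```

This mirrors the multi-workspace test above: one registered handler, a notify naming two workspaces, two handler invocations.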
View file

@@ -33,7 +33,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
         result = await client.query(
             vector=vector,
             limit=10,
-            user="test_user",
             collection="test_collection",
             timeout=30
         )
@@ -45,7 +44,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
         assert isinstance(call_args, DocumentEmbeddingsRequest)
         assert call_args.vector == vector
         assert call_args.limit == 10
-        assert call_args.user == "test_user"
         assert call_args.collection == "test_collection"
     @patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
@@ -104,7 +102,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
         client.request.assert_called_once()
         call_args = client.request.call_args[0][0]
         assert call_args.limit == 20  # Default limit
-        assert call_args.user == "trustgraph"  # Default user
         assert call_args.collection == "default"  # Default collection
     @patch('trustgraph.base.request_response_spec.RequestResponse.__init__')

View file

@@ -40,10 +40,11 @@ def test_flow_initialization_calls_registered_specs():
     spec_two = MagicMock()
     processor = MagicMock(specifications=[spec_one, spec_two])
-    flow = Flow("processor-1", "flow-a", processor, {"answer": 42})
+    flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42})
     assert flow.id == "processor-1"
     assert flow.name == "flow-a"
+    assert flow.workspace == "default"
     assert flow.producer == {}
     assert flow.consumer == {}
     assert flow.parameter == {}
@@ -54,7 +55,7 @@ def test_flow_initialization_calls_registered_specs():
 def test_flow_start_and_stop_visit_all_consumers():
     consumer_one = AsyncMock()
     consumer_two = AsyncMock()
-    flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
+    flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
     flow.consumer = {"one": consumer_one, "two": consumer_two}
     asyncio.run(flow.start())
@@ -67,7 +68,7 @@ def test_flow_start_and_stop_visit_all_consumers():
 def test_flow_call_returns_values_in_priority_order():
-    flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
+    flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
     flow.producer["shared"] = "producer-value"
     flow.consumer["consumer-only"] = "consumer-value"
     flow.consumer["shared"] = "consumer-value"

View file

@@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
         flow_defn = {'config': 'test-config'}
         # Act
-        await processor.start_flow(flow_name, flow_defn)
+        await processor.start_flow("default", flow_name, flow_defn)
         # Assert - Flow should be created with access to processor specifications
-        mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
+        mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn)
         # The flow should have access to the processor's specifications
         # (The exact mechanism depends on Flow implementation)

View file

@@ -78,11 +78,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
         flow_name = 'test-flow'
         flow_defn = {'config': 'test-config'}
-        await processor.start_flow(flow_name, flow_defn)
-        assert flow_name in processor.flows
+        await processor.start_flow("default", flow_name, flow_defn)
+        assert ("default", flow_name) in processor.flows
         mock_flow_class.assert_called_once_with(
-            'test-processor', flow_name, processor, flow_defn
+            'test-processor', flow_name, "default", processor, flow_defn
         )
         mock_flow.start.assert_called_once()
@@ -103,11 +103,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
         mock_flow_class.return_value = mock_flow
         flow_name = 'test-flow'
-        await processor.start_flow(flow_name, {'config': 'test-config'})
-        await processor.stop_flow(flow_name)
-        assert flow_name not in processor.flows
+        await processor.start_flow("default", flow_name, {'config': 'test-config'})
+        await processor.stop_flow("default", flow_name)
+        assert ("default", flow_name) not in processor.flows
         mock_flow.stop.assert_called_once()
     @with_async_processor_patches
@@ -120,7 +120,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
         processor = FlowProcessor(**config)
-        await processor.stop_flow('non-existent-flow')
+        await processor.stop_flow("default", 'non-existent-flow')
         assert processor.flows == {}
@@ -146,11 +146,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
             }
         }
-        await processor.on_configure_flows(config_data, version=1)
-        assert 'test-flow' in processor.flows
+        await processor.on_configure_flows("default", config_data, version=1)
+        assert ("default", 'test-flow') in processor.flows
         mock_flow_class.assert_called_once_with(
-            'test-processor', 'test-flow', processor,
+            'test-processor', 'test-flow', "default", processor,
             {'config': 'test-config'}
         )
         mock_flow.start.assert_called_once()
@@ -171,7 +171,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
             }
         }
-        await processor.on_configure_flows(config_data, version=1)
+        await processor.on_configure_flows("default", config_data, version=1)
         assert processor.flows == {}
@@ -189,7 +189,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
             'other-data': 'some-value'
         }
-        await processor.on_configure_flows(config_data, version=1)
+        await processor.on_configure_flows("default", config_data, version=1)
         assert processor.flows == {}
@@ -216,7 +216,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
             }
         }
-        await processor.on_configure_flows(config_data1, version=1)
+        await processor.on_configure_flows("default", config_data1, version=1)
         config_data2 = {
             'processor:test-processor': {
@@ -224,12 +224,12 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
             }
         }
-        await processor.on_configure_flows(config_data2, version=2)
-        assert 'flow1' not in processor.flows
+        await processor.on_configure_flows("default", config_data2, version=2)
+        assert ("default", 'flow1') not in processor.flows
         mock_flow1.stop.assert_called_once()
-        assert 'flow2' in processor.flows
+        assert ("default", 'flow2') in processor.flows
         mock_flow2.start.assert_called_once()
     @with_async_processor_patches

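The FlowProcessor hunks above all make one change: the flow registry is keyed by `(workspace, name)` tuples rather than bare names, so identically named flows in different workspaces are independent, and `start_flow`/`stop_flow` take the workspace first. A toy sketch of that keying, with invented names:

```python
# Hypothetical sketch: a flow registry keyed by (workspace, name) tuples,
# mirroring the assertions above such as ("default", flow_name) in flows.

flows = {}

def start_flow(workspace, name, defn):
    flows[(workspace, name)] = defn

def stop_flow(workspace, name):
    flows.pop((workspace, name), None)  # tolerate non-existent flows

start_flow("default", "test-flow", {"config": "a"})
start_flow("other", "test-flow", {"config": "b"})  # same name, other workspace
stop_flow("default", "test-flow")
stop_flow("default", "non-existent-flow")  # no-op, does not raise
```

Stopping the flow in one workspace leaves the same-named flow in the other workspace untouched, which is the isolation the tests assert.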
View file

@@ -28,7 +28,6 @@ def sample_text_document():
     """Sample document with moderate length text."""
     metadata = Metadata(
         id="test-doc-1",
-        user="test-user",
         collection="test-collection"
     )
     text = "The quick brown fox jumps over the lazy dog. " * 20
@@ -43,7 +42,6 @@ def long_text_document():
     """Long document for testing multiple chunks."""
     metadata = Metadata(
         id="test-doc-long",
-        user="test-user",
         collection="test-collection"
     )
     # Create a long text that will definitely be chunked
@@ -59,7 +57,6 @@ def unicode_text_document():
     """Document with various unicode characters."""
     metadata = Metadata(
         id="test-doc-unicode",
-        user="test-user",
         collection="test-collection"
     )
     text = """
@@ -84,7 +81,6 @@ def empty_text_document():
     """Empty document for edge case testing."""
     metadata = Metadata(
         id="test-doc-empty",
-        user="test-user",
         collection="test-collection"
     )
     return TextDocument(

View file

@@ -185,7 +185,6 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_text_doc = MagicMock()
         mock_text_doc.metadata = Metadata(
             id="test-doc-123",
-            user="test-user",
             collection="test-collection"
         )
         mock_text_doc.text = b"This is test document content"

View file

@@ -185,7 +185,6 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_text_doc = MagicMock()
         mock_text_doc.metadata = Metadata(
             id="test-doc-456",
-            user="test-user",
             collection="test-collection"
         )
         mock_text_doc.text = b"This is test document content for token chunking"

View file

@@ -109,7 +109,8 @@ class TestListConfigItems:
             url='http://custom.com',
             config_type='prompt',
             format_type='json',
-            token=None
+            token=None,
+            workspace='default'
         )
     def test_list_main_uses_defaults(self):
@@ -128,7 +129,8 @@ class TestListConfigItems:
             url='http://localhost:8088/',
             config_type='prompt',
             format_type='text',
-            token=None
+            token=None,
+            workspace='default'
         )
@@ -196,7 +198,8 @@ class TestGetConfigItem:
             config_type='prompt',
             key='template-1',
             format_type='json',
-            token=None
+            token=None,
+            workspace='default'
         )
@@ -253,7 +256,8 @@ class TestPutConfigItem:
             config_type='prompt',
             key='new-template',
             value='Custom prompt: {input}',
-            token=None
+            token=None,
+            workspace='default'
         )
     def test_put_main_with_stdin_arg(self):
@@ -278,7 +282,8 @@ class TestPutConfigItem:
             config_type='prompt',
             key='stdin-template',
             value=stdin_content,
-            token=None
+            token=None,
+            workspace='default'
         )
     def test_put_main_mutually_exclusive_args(self):
@@ -334,7 +339,8 @@ class TestDeleteConfigItem:
             url='http://custom.com',
             config_type='prompt',
             key='old-template',
-            token=None
+            token=None,
+            workspace='default'
         )

View file

@@ -48,7 +48,7 @@ def knowledge_loader():
     return KnowledgeLoader(
         files=["test.ttl"],
         flow="test-flow",
-        user="test-user",
+        workspace="test-user",
         collection="test-collection",
         document_id="test-doc-123",
         url="http://test.example.com/",
@@ -64,7 +64,7 @@ class TestKnowledgeLoader:
         loader = KnowledgeLoader(
             files=["file1.ttl", "file2.ttl"],
             flow="my-flow",
-            user="user1",
+            workspace="user1",
             collection="col1",
             document_id="doc1",
             url="http://example.com/",
@@ -73,7 +73,7 @@ class TestKnowledgeLoader:
         assert loader.files == ["file1.ttl", "file2.ttl"]
         assert loader.flow == "my-flow"
-        assert loader.user == "user1"
+        assert loader.workspace == "user1"
         assert loader.collection == "col1"
         assert loader.document_id == "doc1"
         assert loader.url == "http://example.com/"
@@ -126,7 +126,7 @@ ex:mary ex:knows ex:bob .
         loader = KnowledgeLoader(
             files=[f.name],
             flow="test-flow",
-            user="test-user",
+            workspace="test-user",
             collection="test-collection",
             document_id="test-doc",
             url="http://test.example.com/"
@@ -151,7 +151,7 @@ ex:mary ex:knows ex:bob .
         loader = KnowledgeLoader(
             files=[temp_turtle_file],
             flow="test-flow",
-            user="test-user",
+            workspace="test-user",
             collection="test-collection",
             document_id="test-doc",
             url="http://test.example.com/",
@@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob .
         # Verify Api was created with correct parameters
         mock_api_class.assert_called_once_with(
             url="http://test.example.com/",
-            token="test-token"
+            token="test-token",
+            workspace="test-user"
         )
         # Verify bulk client was obtained
@@ -174,7 +175,6 @@ ex:mary ex:knows ex:bob .
         call_args = mock_bulk.import_triples.call_args
         assert call_args[1]['flow'] == "test-flow"
         assert call_args[1]['metadata']['id'] == "test-doc"
-        assert call_args[1]['metadata']['user'] == "test-user"
         assert call_args[1]['metadata']['collection'] == "test-collection"
         # Verify import_entity_contexts was called
@@ -198,7 +198,7 @@ class TestCLIArgumentParsing:
             'tg-load-knowledge',
             '-i', 'doc-123',
             '-f', 'my-flow',
-            '-U', 'my-user',
+            '-w', 'my-user',
             '-C', 'my-collection',
             '-u', 'http://custom.example.com/',
             '-t', 'my-token',
@@ -216,7 +216,7 @@ class TestCLIArgumentParsing:
             token='my-token',
             flow='my-flow',
             files=['file1.ttl', 'file2.ttl'],
-            user='my-user',
+            workspace='my-user',
             collection='my-collection'
         )
@@ -242,7 +242,7 @@ class TestCLIArgumentParsing:
         # Verify defaults were used
         call_args = mock_loader_class.call_args[1]
         assert call_args['flow'] == 'default'
-        assert call_args['user'] == 'trustgraph'
+        assert call_args['workspace'] == 'default'
         assert call_args['collection'] == 'default'
         assert call_args['url'] == 'http://localhost:8088/'
         assert call_args['token'] is None
@@ -287,7 +287,7 @@ class TestErrorHandling:
         loader = KnowledgeLoader(
             files=[temp_turtle_file],
             flow="test-flow",
-            user="test-user",
+            workspace="test-user",
             collection="test-collection",
             document_id="test-doc",
             url="http://test.example.com/"

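The CLI hunks above swap the loader's old `-U`/user option (default `trustgraph`) for `-w`/workspace (default `default`). A small argparse sketch of that flag shape, with option names taken from the test invocations above and everything else invented:

```python
import argparse

# Hypothetical sketch of the flag change these tests exercise: the old
# -U/--user option (default "trustgraph") is replaced by -w/--workspace
# (default "default"). Only the three flags relevant here are modelled.

parser = argparse.ArgumentParser(prog="tg-load-knowledge")
parser.add_argument("-w", "--workspace", default="default")
parser.add_argument("-C", "--collection", default="default")
parser.add_argument("-f", "--flow", default="default")

# Explicit flags, as in the parsing test above
args = parser.parse_args(["-w", "my-user", "-C", "my-collection"])

# No flags: the workspace default is now "default", not "trustgraph"
defaults = parser.parse_args([])
```

Note the default change is behavioural, not just a rename: loads that previously landed under the implicit `trustgraph` user now land in the `default` workspace unless `-w` is passed.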
View file

@@ -145,7 +145,8 @@ class TestSetToolStructuredQuery:
             group=None,
             state=None,
             applicable_states=None,
-            token=None
+            token=None,
+            workspace='default'
         )
     def test_set_main_structured_query_no_arguments_needed(self):
@@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery:
             group=None,
             state=None,
             applicable_states=None,
-            token=None
+            token=None,
+            workspace='default'
         )
     def test_valid_types_includes_row_embeddings_query(self):
@@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery:
         show_main()
-        mock_show.assert_called_once_with(url='http://custom.com', token=None)
+        mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default')
 class TestShowToolsRowEmbeddingsQuery:

View file

@@ -73,7 +73,6 @@ class TestSyncDocumentEmbeddingsClient:
         # Act
         result = client.request(
             vector=vector,
-            user="test_user",
             collection="test_collection",
             limit=10,
             timeout=300
@@ -82,7 +81,6 @@ class TestSyncDocumentEmbeddingsClient:
         # Assert
         assert result == ["chunk1", "chunk2", "chunk3"]
         client.call.assert_called_once_with(
-            user="test_user",
             collection="test_collection",
             vector=vector,
             limit=10,
@@ -108,7 +106,6 @@ class TestSyncDocumentEmbeddingsClient:
         # Assert
         assert result == ["test_chunk"]
         client.call.assert_called_once_with(
-            user="trustgraph",
             collection="default",
             vector=vector,
             limit=10,

View file

@@ -31,7 +31,6 @@ def _make_query(
     query = Query(
         rag=rag,
-        user="test-user",
         collection="test-collection",
         verbose=False,
         entity_limit=entity_limit,
@@ -208,7 +207,6 @@ class TestBatchTripleQueries:
         assert calls[0].kwargs["p"] is None
         assert calls[0].kwargs["o"] is None
         assert calls[0].kwargs["limit"] == 15
-        assert calls[0].kwargs["user"] == "test-user"
         assert calls[0].kwargs["collection"] == "test-collection"
         assert calls[0].kwargs["batch_size"] == 20

View file

@@ -28,6 +28,7 @@ def mock_flow_config():
     """Mock flow configuration."""
     mock_config = Mock()
     mock_config.flows = {
+        "test-user": {
         "test-flow": {
             "interfaces": {
                 "triples-store": {"flow": "test-triples-queue"},
@@ -35,6 +36,7 @@ def mock_flow_config():
             }
         }
     }
+    }
     mock_config.pulsar_client = AsyncMock()
     return mock_config
@@ -43,7 +45,7 @@ def mock_flow_config():
 def mock_request():
     """Mock knowledge load request."""
     request = Mock()
-    request.user = "test-user"
+    request.workspace = "test-user"
     request.id = "test-doc-id"
     request.collection = "test-collection"
     request.flow = "test-flow"
@@ -71,7 +73,6 @@ def sample_triples():
     return Triples(
         metadata=Metadata(
             id="test-doc-id",
-            user="test-user",
             collection="default",  # This should be overridden
         ),
         triples=[
@@ -90,7 +91,6 @@ def sample_graph_embeddings():
     return GraphEmbeddings(
         metadata=Metadata(
             id="test-doc-id",
-            user="test-user",
             collection="default",  # This should be overridden
         ),
         entities=[
@@ -146,7 +146,6 @@ class TestKnowledgeManagerLoadCore:
         mock_triples_pub.send.assert_called_once()
         sent_triples = mock_triples_pub.send.call_args[0][1]
         assert sent_triples.metadata.collection == "test-collection"
-        assert sent_triples.metadata.user == "test-user"
         assert sent_triples.metadata.id == "test-doc-id"

     @pytest.mark.asyncio
@@ -185,7 +184,6 @@ class TestKnowledgeManagerLoadCore:
         mock_ge_pub.send.assert_called_once()
         sent_ge = mock_ge_pub.send.call_args[0][1]
         assert sent_ge.metadata.collection == "test-collection"
-        assert sent_ge.metadata.user == "test-user"
         assert sent_ge.metadata.id == "test-doc-id"

     @pytest.mark.asyncio
@@ -193,7 +191,7 @@ class TestKnowledgeManagerLoadCore:
         """Test that load_kg_core falls back to 'default' when request.collection is None."""
         # Create request with None collection
         mock_request = Mock()
-        mock_request.user = "test-user"
+        mock_request.workspace = "test-user"
         mock_request.id = "test-doc-id"
         mock_request.collection = None  # Should fall back to "default"
         mock_request.flow = "test-flow"
@@ -269,7 +267,7 @@ class TestKnowledgeManagerLoadCore:
         """Test that load_kg_core validates flow configuration before processing."""
         # Request with invalid flow
         mock_request = Mock()
-        mock_request.user = "test-user"
+        mock_request.workspace = "test-user"
         mock_request.id = "test-doc-id"
         mock_request.collection = "test-collection"
         mock_request.flow = "invalid-flow"  # Not in mock_flow_config.flows
@@ -297,7 +295,7 @@ class TestKnowledgeManagerLoadCore:
         # Test missing ID
         mock_request = Mock()
-        mock_request.user = "test-user"
+        mock_request.workspace = "test-user"
         mock_request.id = None  # Missing
         mock_request.collection = "test-collection"
         mock_request.flow = "test-flow"
@@ -323,7 +321,7 @@ class TestKnowledgeManagerOtherMethods:
     async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples):
         """Test that get_kg_core preserves collection field from stored data."""
         mock_request = Mock()
-        mock_request.user = "test-user"
+        mock_request.workspace = "test-user"
         mock_request.id = "test-doc-id"
         mock_respond = AsyncMock()
@@ -354,7 +352,7 @@ class TestKnowledgeManagerOtherMethods:
     async def test_list_kg_cores(self, knowledge_manager):
         """Test listing knowledge cores."""
         mock_request = Mock()
-        mock_request.user = "test-user"
+        mock_request.workspace = "test-user"
         mock_respond = AsyncMock()
@@ -376,7 +374,7 @@ class TestKnowledgeManagerOtherMethods:
     async def test_delete_kg_core(self, knowledge_manager):
         """Test deleting knowledge cores."""
         mock_request = Mock()
-        mock_request.user = "test-user"
+        mock_request.workspace = "test-user"
         mock_request.id = "test-doc-id"
         mock_respond = AsyncMock()

View file

@@ -237,7 +237,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
         # Mock message with inline data
         content = b"# Document Title\nBody text content."
-        mock_metadata = Metadata(id="test-doc", user="testuser",
+        mock_metadata = Metadata(id="test-doc",
                                  collection="default")
         mock_document = Document(
             metadata=mock_metadata,
@@ -294,7 +294,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
         # Mock message
         content = b"fake pdf"
-        mock_metadata = Metadata(id="test-doc", user="testuser",
+        mock_metadata = Metadata(id="test-doc",
                                  collection="default")
         mock_document = Document(
             metadata=mock_metadata,
@@ -345,7 +345,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
         ]
         content = b"fake pdf"
-        mock_metadata = Metadata(id="test-doc", user="testuser",
+        mock_metadata = Metadata(id="test-doc",
                                  collection="default")
         mock_document = Document(
             metadata=mock_metadata,

View file

@@ -12,7 +12,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_basic(self):
         """Test basic collection name creation"""
         result = make_safe_collection_name(
-            user="test_user",
+            workspace="test_user",
             collection="test_collection",
             prefix="doc"
         )
@@ -21,7 +21,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_with_special_characters(self):
         """Test collection name creation with special characters that need sanitization"""
         result = make_safe_collection_name(
-            user="user@domain.com",
+            workspace="user@domain.com",
             collection="test-collection.v2",
             prefix="entity"
         )
@@ -30,7 +30,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_with_unicode(self):
         """Test collection name creation with Unicode characters"""
         result = make_safe_collection_name(
-            user="测试用户",
+            workspace="测试用户",
             collection="colección_española",
             prefix="doc"
         )
@@ -39,7 +39,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_with_spaces(self):
         """Test collection name creation with spaces"""
         result = make_safe_collection_name(
-            user="test user",
+            workspace="test user",
             collection="my test collection",
             prefix="entity"
         )
@@ -48,7 +48,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self):
         """Test collection name creation with multiple consecutive special characters"""
         result = make_safe_collection_name(
-            user="user@@@domain!!!",
+            workspace="user@@@domain!!!",
             collection="test---collection...v2",
             prefix="doc"
         )
@@ -57,7 +57,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_with_leading_trailing_underscores(self):
         """Test collection name creation with leading/trailing special characters"""
         result = make_safe_collection_name(
-            user="__test_user__",
+            workspace="__test_user__",
             collection="@@test_collection##",
             prefix="entity"
         )
@@ -66,7 +66,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_empty_user(self):
         """Test collection name creation with empty user (should fallback to 'default')"""
         result = make_safe_collection_name(
-            user="",
+            workspace="",
             collection="test_collection",
             prefix="doc"
         )
@@ -75,7 +75,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_empty_collection(self):
         """Test collection name creation with empty collection (should fallback to 'default')"""
         result = make_safe_collection_name(
-            user="test_user",
+            workspace="test_user",
             collection="",
             prefix="doc"
         )
@@ -84,7 +84,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_both_empty(self):
         """Test collection name creation with both user and collection empty"""
         result = make_safe_collection_name(
-            user="",
+            workspace="",
             collection="",
             prefix="doc"
         )
@@ -93,7 +93,7 @@ class TestMilvusCollectionNaming:
    def test_make_safe_collection_name_only_special_characters(self):
         """Test collection name creation with only special characters (should fallback to 'default')"""
         result = make_safe_collection_name(
-            user="@@@!!!",
+            workspace="@@@!!!",
             collection="---###",
             prefix="entity"
         )
@@ -102,7 +102,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_whitespace_only(self):
         """Test collection name creation with whitespace-only strings"""
         result = make_safe_collection_name(
-            user=" \n\t ",
+            workspace=" \n\t ",
             collection=" \r\n ",
             prefix="doc"
         )
@@ -111,7 +111,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_mixed_valid_invalid_chars(self):
         """Test collection name creation with mixed valid and invalid characters"""
         result = make_safe_collection_name(
-            user="user123@test",
+            workspace="user123@test",
             collection="coll_2023.v1",
             prefix="entity"
         )
@@ -147,7 +147,7 @@ class TestMilvusCollectionNaming:
         long_collection = "b" * 100

         result = make_safe_collection_name(
-            user=long_user,
+            workspace=long_user,
             collection=long_collection,
             prefix="doc"
         )
@@ -159,7 +159,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_numeric_values(self):
         """Test collection name creation with numeric user/collection values"""
         result = make_safe_collection_name(
-            user="user123",
+            workspace="user123",
             collection="collection456",
             prefix="doc"
         )
@@ -168,7 +168,7 @@ class TestMilvusCollectionNaming:
     def test_make_safe_collection_name_case_sensitivity(self):
         """Test that collection name creation preserves case"""
         result = make_safe_collection_name(
-            user="TestUser",
+            workspace="TestUser",
             collection="TestCollection",
             prefix="Doc"
         )

View file

@@ -20,9 +20,8 @@ def processor():
     )

-def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
-                        user="test", collection="default"):
-    metadata = Metadata(id=doc_id, user=user, collection=collection)
+def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", collection="default"):
+    metadata = Metadata(id=doc_id, collection=collection)
     value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
     msg = MagicMock()
     msg.value.return_value = value
@@ -127,7 +126,7 @@ class TestDocumentEmbeddingsProcessor:
     @pytest.mark.asyncio
     async def test_metadata_preserved(self, processor):
         """Output should carry the original metadata."""
-        msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
+        msg = _make_chunk_message(collection="reports", doc_id="d1")

         mock_request = AsyncMock(return_value=EmbeddingsResponse(
             error=None, vectors=[[0.0]]
@@ -144,7 +143,6 @@ class TestDocumentEmbeddingsProcessor:
         await processor.on_message(msg, MagicMock(), flow)

         result = mock_output.send.call_args[0][0]
-        assert result.metadata.user == "alice"
         assert result.metadata.collection == "reports"
         assert result.metadata.id == "d1"

View file

@@ -27,8 +27,8 @@ def _make_entity_context(name, context, chunk_id="chunk-1"):
     return MagicMock(entity=entity, context=context, chunk_id=chunk_id)

-def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
-    metadata = Metadata(id=doc_id, user=user, collection=collection)
+def _make_message(entities, doc_id="doc-1", collection="default"):
+    metadata = Metadata(id=doc_id, collection=collection)
     value = EntityContexts(metadata=metadata, entities=entities)
     msg = MagicMock()
     msg.value.return_value = value
@@ -151,7 +151,7 @@ class TestGraphEmbeddingsBatchProcessing:
             _make_entity_context(f"E{i}", f"ctx {i}")
             for i in range(5)
         ]
-        msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
+        msg = _make_message(entities, doc_id="doc-42", collection="main")

         mock_embed = AsyncMock(return_value=[[0.0]] * 5)
         mock_output = AsyncMock()
@@ -168,7 +168,6 @@ class TestGraphEmbeddingsBatchProcessing:
         for call in mock_output.send.call_args_list:
             result = call[0][0]
             assert result.metadata.id == "doc-42"
-            assert result.metadata.user == "alice"
             assert result.metadata.collection == "main"

     @pytest.mark.asyncio

View file

@@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
             }
         }

-        await processor.on_schema_config(config_data, 1)
+        await processor.on_schema_config("default", config_data, 1)

-        assert 'customers' in processor.schemas
-        assert processor.schemas['customers'].name == 'customers'
-        assert len(processor.schemas['customers'].fields) == 3
+        assert 'customers' in processor.schemas["default"]
+        assert processor.schemas["default"]['customers'].name == 'customers'
+        assert len(processor.schemas["default"]['customers'].fields) == 3

     async def test_on_schema_config_handles_missing_type(self):
         """Test that missing schema type is handled gracefully"""
@@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
             'other_type': {}
         }

-        await processor.on_schema_config(config_data, 1)
+        await processor.on_schema_config("default", config_data, 1)

-        assert processor.schemas == {}
+        assert processor.schemas.get("default", {}) == {}

     async def test_on_message_drops_unknown_collection(self):
         """Test that messages for unknown collections are dropped"""
@@ -285,7 +285,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
         }
         processor = Processor(**config)

-        processor.known_collections[('test_user', 'test_collection')] = {}
+        processor.known_collections[('default', 'test_collection')] = {}
         # No schemas registered

         metadata = MagicMock()
@@ -322,10 +322,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
         }
         processor = Processor(**config)

-        processor.known_collections[('test_user', 'test_collection')] = {}
+        processor.known_collections[('default', 'test_collection')] = {}

         # Set up schema
-        processor.schemas['customers'] = RowSchema(
+        processor.schemas["default"] = {
+            'customers': RowSchema(
             name='customers',
             description='Customer records',
             fields=[
@@ -333,6 +334,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
                 Field(name='name', type='text', indexed=True),
             ]
         )
+        }

         metadata = MagicMock()
         metadata.user = 'test_user'
@@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
             return MagicMock()

         mock_flow = MagicMock(side_effect=flow_factory)
+        mock_flow.workspace = "default"

         await processor.on_message(mock_msg, MagicMock(), mock_flow)

View file

@ -0,0 +1,200 @@
"""
Unit tests for extract_with_simplified_format.
Regression guard for the bug where the extractor read
``result.object`` (singular, used for response_type="json") instead of
``result.objects`` (plural, used for response_type="jsonl"). The
extract-with-ontologies prompt is JSONL, so reading the wrong field
silently dropped every extraction and left the knowledge graph
populated only by ontology schema + document provenance.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
from trustgraph.base import PromptResult
@pytest.fixture
def extractor():
"""Create a Processor instance without running its heavy __init__.
Matches the pattern used in test_prompt_and_extraction.py: only
the attributes the code under test touches need to be set.
"""
ex = object.__new__(Processor)
ex.URI_PREFIXES = {
"rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
"rdfs:": "http://www.w3.org/2000/01/rdf-schema#",
"owl:": "http://www.w3.org/2002/07/owl#",
"xsd:": "http://www.w3.org/2001/XMLSchema#",
}
return ex
@pytest.fixture
def food_subset():
"""A minimal food ontology subset the extracted entities reference."""
return OntologySubset(
ontology_id="food",
classes={
"Recipe": {
"uri": "http://purl.org/ontology/fo/Recipe",
"type": "owl:Class",
"labels": [{"value": "Recipe", "lang": "en-gb"}],
"comment": "A Recipe.",
},
"Food": {
"uri": "http://purl.org/ontology/fo/Food",
"type": "owl:Class",
"labels": [{"value": "Food", "lang": "en-gb"}],
"comment": "A Food.",
},
},
object_properties={
"ingredients": {
"uri": "http://purl.org/ontology/fo/ingredients",
"type": "owl:ObjectProperty",
"labels": [{"value": "ingredients", "lang": "en-gb"}],
"comment": "Relates a recipe to its ingredients.",
"domain": "Recipe",
"range": "Food",
},
},
datatype_properties={},
metadata={
"name": "Food Ontology",
"namespace": "http://purl.org/ontology/fo/",
},
)
def _flow_with_prompt_result(prompt_result):
"""Build the ``flow(name)`` callable the extractor invokes.
``extract_with_simplified_format`` calls
``flow("prompt-request").prompt(...)`` so we need ``flow`` to be
callable, return an object whose ``.prompt`` is an AsyncMock that
resolves to ``prompt_result``.
"""
prompt_service = MagicMock()
prompt_service.prompt = AsyncMock(return_value=prompt_result)
def flow(name):
assert name == "prompt-request", (
f"extractor should only invoke flow('prompt-request'), "
f"got {name!r}"
)
return prompt_service
return flow, prompt_service.prompt
class TestReadsObjectsForJsonlPrompt:
"""extract-with-ontologies is a JSONL prompt; the extractor must
read ``result.objects``, not ``result.object``."""
async def test_populated_objects_produces_triples(
self, extractor, food_subset,
):
"""Happy path: PromptResult with populated .objects -> non-empty
triples list."""
prompt_result = PromptResult(
response_type="jsonl",
objects=[
{"type": "entity", "entity": "Cornish Pasty",
"entity_type": "Recipe"},
{"type": "entity", "entity": "beef",
"entity_type": "Food"},
{"type": "relationship",
"subject": "Cornish Pasty", "subject_type": "Recipe",
"relation": "ingredients",
"object": "beef", "object_type": "Food"},
],
)
flow, prompt_mock = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "some chunk", food_subset, {"text": "some chunk"},
)
prompt_mock.assert_awaited_once()
assert triples, (
"extract_with_simplified_format returned no triples; if "
"this fails, the extractor is probably reading .object "
"instead of .objects again"
)
async def test_none_objects_returns_empty_without_crashing(
self, extractor, food_subset,
):
"""The exact shape that hit production on v2.3: the extractor
was reading ``.object`` for a JSONL prompt, which returned
``None`` and tripped the parser's 'Unexpected response type'
path. With the fix we read ``.objects``; if that's also
``None`` we must still return ``[]`` cleanly, not crash."""
prompt_result = PromptResult(
response_type="jsonl",
objects=None,
)
flow, _ = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "chunk", food_subset, {"text": "chunk"},
)
assert triples == []
async def test_empty_objects_returns_empty(
self, extractor, food_subset,
):
"""Valid JSONL response with zero entries should yield zero
triples, not raise."""
prompt_result = PromptResult(
response_type="jsonl",
objects=[],
)
flow, _ = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "chunk", food_subset, {"text": "chunk"},
)
assert triples == []
async def test_ignores_object_field_for_jsonl_prompt(
self, extractor, food_subset,
):
"""If ``.object`` is somehow set but ``.objects`` is None, the
extractor must not silently fall back to ``.object``. This
guards against a well-meaning regression that "helpfully"
re-adds fallback fields.
The extractor should read only ``.objects`` for this prompt;
when that is None we expect the empty-result path.
"""
prompt_result = PromptResult(
response_type="json",
object={"not": "the field we should be reading"},
objects=None,
)
flow, _ = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "chunk", food_subset, {"text": "chunk"},
)
assert triples == [], (
"Extractor fell back to .object for a JSONL prompt — "
"this is the regression shape we are trying to prevent"
)
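
The new test module above guards against reading ``result.object`` where ``result.objects`` is required. As a standalone illustration of that failure mode (the ``PromptResult`` dataclass and ``extracted_items`` helper here are simplified stand-ins for illustration, not TrustGraph's actual classes), the split between the two fields reduces to:

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class PromptResult:
    response_type: str
    object: Optional[Any] = None    # populated when response_type == "json"
    objects: Optional[list] = None  # populated when response_type == "jsonl"

def extracted_items(result: PromptResult) -> list:
    # JSONL prompts fill .objects; .object stays None, so reading the
    # singular field for a JSONL response would silently yield nothing --
    # every extraction dropped with no error raised.
    if result.response_type == "jsonl":
        return result.objects or []
    return [result.object] if result.object is not None else []

jsonl = PromptResult(
    response_type="jsonl",
    objects=[{"type": "entity", "entity": "beef"}],
)
print(extracted_items(jsonl))  # [{'type': 'entity', 'entity': 'beef'}]
print(extracted_items(PromptResult(response_type="jsonl")))  # []
```

Reading the wrong field produces an empty-but-valid result rather than an exception, which is why the regression only showed up as a sparsely populated knowledge graph.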

View file

@@ -34,11 +34,10 @@ def _make_defn(entity, definition):
     return {"entity": entity, "definition": definition}

-def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
-                    user="user-1", collection="col-1", document_id=""):
+def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
     chunk = Chunk(
         metadata=Metadata(
-            id=meta_id, root=root, user=user, collection=collection,
+            id=meta_id, root=root, collection=collection,
         ),
         chunk=text.encode("utf-8"),
         document_id=document_id,
@@ -229,8 +228,7 @@ class TestMetadataPreservation:
         defs = [_make_defn("X", "def X")]
         flow, triples_pub, _, _ = _make_flow(defs)
         msg = _make_chunk_msg(
-            "text", meta_id="c-1", root="r-1",
-            user="u-1", collection="coll-1",
+            "text", meta_id="c-1", root="r-1", collection="coll-1",
         )

         await proc.on_message(msg, MagicMock(), flow)
@@ -238,7 +236,6 @@ class TestMetadataPreservation:
         for triples_msg in _sent_triples(triples_pub):
             assert triples_msg.metadata.id == "c-1"
             assert triples_msg.metadata.root == "r-1"
-            assert triples_msg.metadata.user == "u-1"
             assert triples_msg.metadata.collection == "coll-1"

     @pytest.mark.asyncio
@@ -247,8 +244,7 @@ class TestMetadataPreservation:
         defs = [_make_defn("X", "def X")]
         flow, _, ecs_pub, _ = _make_flow(defs)
         msg = _make_chunk_msg(
-            "text", meta_id="c-2", root="r-2",
-            user="u-2", collection="coll-2",
+            "text", meta_id="c-2", root="r-2", collection="coll-2",
         )

         await proc.on_message(msg, MagicMock(), flow)

View file

@@ -38,12 +38,11 @@ def _make_rel(subject, predicate, obj, object_entity=True):
     }

-def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
-                    user="user-1", collection="col-1", document_id=""):
+def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
     """Build a mock message wrapping a Chunk."""
     chunk = Chunk(
         metadata=Metadata(
-            id=meta_id, root=root, user=user, collection=collection,
+            id=meta_id, root=root, collection=collection,
         ),
         chunk=text.encode("utf-8"),
         document_id=document_id,
@@ -189,8 +188,7 @@ class TestMetadataPreservation:
         rels = [_make_rel("X", "rel", "Y")]
         flow, pub, _ = _make_flow(rels)
         msg = _make_chunk_msg(
-            "text", meta_id="c-1", root="r-1",
-            user="u-1", collection="coll-1",
+            "text", meta_id="c-1", root="r-1", collection="coll-1",
         )

         await proc.on_message(msg, MagicMock(), flow)
@@ -198,7 +196,6 @@ class TestMetadataPreservation:
         for triples_msg in _sent_triples(pub):
             assert triples_msg.metadata.id == "c-1"
             assert triples_msg.metadata.root == "r-1"
-            assert triples_msg.metadata.user == "u-1"
             assert triples_msg.metadata.collection == "coll-1"

View file

@@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader
 ConfigReceiver.config_loader = Mock()
 
+
+def _notify(version, changes):
+    msg = Mock()
+    msg.value.return_value = Mock(version=version, changes=changes)
+    return msg
+
 
 class TestConfigReceiver:
     """Test cases for ConfigReceiver class"""
@@ -47,98 +53,70 @@ class TestConfigReceiver:
         assert handler2 in config_receiver.flow_handlers
 
     @pytest.mark.asyncio
-    async def test_on_config_notify_new_version(self):
-        """Test on_config_notify triggers fetch for newer version"""
+    async def test_on_config_notify_new_version_fetches_per_workspace(self):
+        """Notify with newer version fetches each affected workspace."""
         mock_backend = Mock()
         config_receiver = ConfigReceiver(mock_backend)
         config_receiver.config_version = 1
 
-        # Mock fetch_and_apply
         fetch_calls = []
-        async def mock_fetch(**kwargs):
-            fetch_calls.append(kwargs)
-        config_receiver.fetch_and_apply = mock_fetch
 
-        # Create notify message with newer version
-        mock_msg = Mock()
-        mock_msg.value.return_value = Mock(version=2, types=["flow"])
+        async def mock_fetch(workspace, retry=False):
+            fetch_calls.append(workspace)
 
-        await config_receiver.on_config_notify(mock_msg, None, None)
+        config_receiver.fetch_and_apply_workspace = mock_fetch
 
-        assert len(fetch_calls) == 1
+        msg = _notify(2, {"flow": ["ws1", "ws2"]})
+        await config_receiver.on_config_notify(msg, None, None)
+
+        assert set(fetch_calls) == {"ws1", "ws2"}
+        assert config_receiver.config_version == 2
 
     @pytest.mark.asyncio
     async def test_on_config_notify_old_version_ignored(self):
-        """Test on_config_notify ignores older versions"""
+        """Older-version notifies are ignored."""
         mock_backend = Mock()
         config_receiver = ConfigReceiver(mock_backend)
         config_receiver.config_version = 5
 
         fetch_calls = []
-        async def mock_fetch(**kwargs):
-            fetch_calls.append(kwargs)
-        config_receiver.fetch_and_apply = mock_fetch
 
-        # Create notify message with older version
-        mock_msg = Mock()
-        mock_msg.value.return_value = Mock(version=3, types=["flow"])
+        async def mock_fetch(workspace, retry=False):
+            fetch_calls.append(workspace)
 
-        await config_receiver.on_config_notify(mock_msg, None, None)
+        config_receiver.fetch_and_apply_workspace = mock_fetch
 
-        assert len(fetch_calls) == 0
+        msg = _notify(3, {"flow": ["ws1"]})
+        await config_receiver.on_config_notify(msg, None, None)
+
+        assert fetch_calls == []
 
     @pytest.mark.asyncio
     async def test_on_config_notify_irrelevant_types_ignored(self):
-        """Test on_config_notify ignores types the gateway doesn't care about"""
+        """Notifies without flow changes advance version but skip fetch."""
         mock_backend = Mock()
         config_receiver = ConfigReceiver(mock_backend)
         config_receiver.config_version = 1
 
         fetch_calls = []
-        async def mock_fetch(**kwargs):
-            fetch_calls.append(kwargs)
-        config_receiver.fetch_and_apply = mock_fetch
 
-        # Create notify message with non-flow type
-        mock_msg = Mock()
-        mock_msg.value.return_value = Mock(version=2, types=["prompt"])
+        async def mock_fetch(workspace, retry=False):
+            fetch_calls.append(workspace)
 
-        await config_receiver.on_config_notify(mock_msg, None, None)
+        config_receiver.fetch_and_apply_workspace = mock_fetch
 
-        # Version should be updated but no fetch
-        assert len(fetch_calls) == 0
+        msg = _notify(2, {"prompt": ["ws1"]})
+        await config_receiver.on_config_notify(msg, None, None)
+
+        assert fetch_calls == []
         assert config_receiver.config_version == 2
 
-    @pytest.mark.asyncio
-    async def test_on_config_notify_flow_type_triggers_fetch(self):
-        """Test on_config_notify fetches for flow-related types"""
-        mock_backend = Mock()
-        config_receiver = ConfigReceiver(mock_backend)
-        config_receiver.config_version = 1
-        fetch_calls = []
-        async def mock_fetch(**kwargs):
-            fetch_calls.append(kwargs)
-        config_receiver.fetch_and_apply = mock_fetch
-        for type_name in ["flow"]:
-            fetch_calls.clear()
-            config_receiver.config_version = 1
-            mock_msg = Mock()
-            mock_msg.value.return_value = Mock(version=2, types=[type_name])
-            await config_receiver.on_config_notify(mock_msg, None, None)
-            assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
-
     @pytest.mark.asyncio
     async def test_on_config_notify_exception_handling(self):
-        """Test on_config_notify handles exceptions gracefully"""
+        """on_config_notify swallows exceptions from message decode."""
         mock_backend = Mock()
         config_receiver = ConfigReceiver(mock_backend)
 
-        # Create notify message that causes an exception
         mock_msg = Mock()
         mock_msg.value.side_effect = Exception("Test exception")
@@ -146,19 +124,18 @@ class TestConfigReceiver:
         await config_receiver.on_config_notify(mock_msg, None, None)
 
     @pytest.mark.asyncio
-    async def test_fetch_and_apply_with_new_flows(self):
-        """Test fetch_and_apply starts new flows"""
+    async def test_fetch_and_apply_workspace_starts_new_flows(self):
+        """fetch_and_apply_workspace starts newly-configured flows."""
         mock_backend = Mock()
         config_receiver = ConfigReceiver(mock_backend)
 
-        # Mock _create_config_client to return a mock client
         mock_resp = Mock()
         mock_resp.error = None
         mock_resp.version = 5
         mock_resp.config = {
             "flow": {
                 "flow1": '{"name": "test_flow_1"}',
-                "flow2": '{"name": "test_flow_2"}'
+                "flow2": '{"name": "test_flow_2"}',
             }
         }
@@ -167,36 +144,39 @@ class TestConfigReceiver:
         config_receiver._create_config_client = Mock(return_value=mock_client)
 
         start_flow_calls = []
-        async def mock_start_flow(id, flow):
-            start_flow_calls.append((id, flow))
+
+        async def mock_start_flow(workspace, id, flow):
+            start_flow_calls.append((workspace, id, flow))
 
         config_receiver.start_flow = mock_start_flow
 
-        await config_receiver.fetch_and_apply()
+        await config_receiver.fetch_and_apply_workspace("default")
 
         assert config_receiver.config_version == 5
-        assert "flow1" in config_receiver.flows
-        assert "flow2" in config_receiver.flows
+        assert "flow1" in config_receiver.flows["default"]
+        assert "flow2" in config_receiver.flows["default"]
         assert len(start_flow_calls) == 2
+        assert all(c[0] == "default" for c in start_flow_calls)
 
     @pytest.mark.asyncio
-    async def test_fetch_and_apply_with_removed_flows(self):
-        """Test fetch_and_apply stops removed flows"""
+    async def test_fetch_and_apply_workspace_stops_removed_flows(self):
+        """fetch_and_apply_workspace stops flows no longer configured."""
         mock_backend = Mock()
         config_receiver = ConfigReceiver(mock_backend)
 
-        # Pre-populate with existing flows
         config_receiver.flows = {
-            "flow1": {"name": "test_flow_1"},
-            "flow2": {"name": "test_flow_2"}
+            "default": {
+                "flow1": {"name": "test_flow_1"},
+                "flow2": {"name": "test_flow_2"},
+            }
         }
 
-        # Config now only has flow1
         mock_resp = Mock()
         mock_resp.error = None
         mock_resp.version = 5
         mock_resp.config = {
             "flow": {
-                "flow1": '{"name": "test_flow_1"}'
+                "flow1": '{"name": "test_flow_1"}',
             }
         }
@@ -205,20 +185,22 @@ class TestConfigReceiver:
         config_receiver._create_config_client = Mock(return_value=mock_client)
 
         stop_flow_calls = []
-        async def mock_stop_flow(id, flow):
-            stop_flow_calls.append((id, flow))
+
+        async def mock_stop_flow(workspace, id, flow):
+            stop_flow_calls.append((workspace, id, flow))
 
         config_receiver.stop_flow = mock_stop_flow
 
-        await config_receiver.fetch_and_apply()
+        await config_receiver.fetch_and_apply_workspace("default")
 
-        assert "flow1" in config_receiver.flows
-        assert "flow2" not in config_receiver.flows
+        assert "flow1" in config_receiver.flows["default"]
+        assert "flow2" not in config_receiver.flows["default"]
         assert len(stop_flow_calls) == 1
-        assert stop_flow_calls[0][0] == "flow2"
+        assert stop_flow_calls[0][:2] == ("default", "flow2")
 
     @pytest.mark.asyncio
-    async def test_fetch_and_apply_with_no_flows(self):
-        """Test fetch_and_apply with empty config"""
+    async def test_fetch_and_apply_workspace_with_no_flows(self):
+        """Empty workspace config clears any local flow state."""
         mock_backend = Mock()
         config_receiver = ConfigReceiver(mock_backend)
@@ -231,88 +213,100 @@ class TestConfigReceiver:
         mock_client.request.return_value = mock_resp
         config_receiver._create_config_client = Mock(return_value=mock_client)
 
-        await config_receiver.fetch_and_apply()
+        await config_receiver.fetch_and_apply_workspace("default")
 
-        assert config_receiver.flows == {}
+        assert config_receiver.flows.get("default", {}) == {}
         assert config_receiver.config_version == 1
 
     @pytest.mark.asyncio
     async def test_start_flow_with_handlers(self):
-        """Test start_flow method with multiple handlers"""
+        """start_flow fans out to every registered flow handler."""
         mock_backend = Mock()
         config_receiver = ConfigReceiver(mock_backend)
 
         handler1 = Mock()
-        handler1.start_flow = Mock()
+        handler1.start_flow = AsyncMock()
         handler2 = Mock()
-        handler2.start_flow = Mock()
+        handler2.start_flow = AsyncMock()
 
         config_receiver.add_handler(handler1)
         config_receiver.add_handler(handler2)
 
         flow_data = {"name": "test_flow", "steps": []}
-        await config_receiver.start_flow("flow1", flow_data)
+        await config_receiver.start_flow("default", "flow1", flow_data)
 
-        handler1.start_flow.assert_called_once_with("flow1", flow_data)
-        handler2.start_flow.assert_called_once_with("flow1", flow_data)
+        handler1.start_flow.assert_awaited_once_with(
+            "default", "flow1", flow_data
+        )
+        handler2.start_flow.assert_awaited_once_with(
+            "default", "flow1", flow_data
+        )
 
     @pytest.mark.asyncio
     async def test_start_flow_with_handler_exception(self):
-        """Test start_flow method handles handler exceptions"""
+        """Handler exceptions in start_flow do not propagate."""
         mock_backend = Mock()
         config_receiver = ConfigReceiver(mock_backend)
 
         handler = Mock()
-        handler.start_flow = Mock(side_effect=Exception("Handler error"))
+        handler.start_flow = AsyncMock(side_effect=Exception("Handler error"))
         config_receiver.add_handler(handler)
 
         flow_data = {"name": "test_flow", "steps": []}
 
         # Should not raise
-        await config_receiver.start_flow("flow1", flow_data)
-        handler.start_flow.assert_called_once_with("flow1", flow_data)
+        await config_receiver.start_flow("default", "flow1", flow_data)
+        handler.start_flow.assert_awaited_once_with(
+            "default", "flow1", flow_data
+        )
 
     @pytest.mark.asyncio
     async def test_stop_flow_with_handlers(self):
-        """Test stop_flow method with multiple handlers"""
+        """stop_flow fans out to every registered flow handler."""
         mock_backend = Mock()
         config_receiver = ConfigReceiver(mock_backend)
 
         handler1 = Mock()
-        handler1.stop_flow = Mock()
+        handler1.stop_flow = AsyncMock()
         handler2 = Mock()
-        handler2.stop_flow = Mock()
+        handler2.stop_flow = AsyncMock()
 
         config_receiver.add_handler(handler1)
         config_receiver.add_handler(handler2)
 
         flow_data = {"name": "test_flow", "steps": []}
-        await config_receiver.stop_flow("flow1", flow_data)
+        await config_receiver.stop_flow("default", "flow1", flow_data)
 
-        handler1.stop_flow.assert_called_once_with("flow1", flow_data)
-        handler2.stop_flow.assert_called_once_with("flow1", flow_data)
+        handler1.stop_flow.assert_awaited_once_with(
+            "default", "flow1", flow_data
+        )
+        handler2.stop_flow.assert_awaited_once_with(
+            "default", "flow1", flow_data
+        )
 
     @pytest.mark.asyncio
     async def test_stop_flow_with_handler_exception(self):
-        """Test stop_flow method handles handler exceptions"""
+        """Handler exceptions in stop_flow do not propagate."""
         mock_backend = Mock()
        config_receiver = ConfigReceiver(mock_backend)
 
         handler = Mock()
-        handler.stop_flow = Mock(side_effect=Exception("Handler error"))
+        handler.stop_flow = AsyncMock(side_effect=Exception("Handler error"))
         config_receiver.add_handler(handler)
 
         flow_data = {"name": "test_flow", "steps": []}
 
         # Should not raise
-        await config_receiver.stop_flow("flow1", flow_data)
-        handler.stop_flow.assert_called_once_with("flow1", flow_data)
+        await config_receiver.stop_flow("default", "flow1", flow_data)
+        handler.stop_flow.assert_awaited_once_with(
+            "default", "flow1", flow_data
+        )
 
     @patch('asyncio.create_task')
     @pytest.mark.asyncio
@@ -329,25 +323,25 @@ class TestConfigReceiver:
         mock_create_task.assert_called_once()
 
     @pytest.mark.asyncio
-    async def test_fetch_and_apply_mixed_flow_operations(self):
-        """Test fetch_and_apply with mixed add/remove operations"""
+    async def test_fetch_and_apply_workspace_mixed_flow_operations(self):
+        """fetch_and_apply_workspace adds, keeps and removes flows in one pass."""
         mock_backend = Mock()
         config_receiver = ConfigReceiver(mock_backend)
 
-        # Pre-populate
         config_receiver.flows = {
-            "flow1": {"name": "test_flow_1"},
-            "flow2": {"name": "test_flow_2"}
+            "default": {
+                "flow1": {"name": "test_flow_1"},
+                "flow2": {"name": "test_flow_2"},
+            }
         }
 
-        # Config removes flow1, keeps flow2, adds flow3
         mock_resp = Mock()
         mock_resp.error = None
         mock_resp.version = 5
         mock_resp.config = {
             "flow": {
                 "flow2": '{"name": "test_flow_2"}',
-                "flow3": '{"name": "test_flow_3"}'
+                "flow3": '{"name": "test_flow_3"}',
             }
         }
@@ -358,20 +352,22 @@ class TestConfigReceiver:
         start_calls = []
         stop_calls = []
 
-        async def mock_start_flow(id, flow):
-            start_calls.append((id, flow))
-        async def mock_stop_flow(id, flow):
-            stop_calls.append((id, flow))
+        async def mock_start_flow(workspace, id, flow):
+            start_calls.append((workspace, id, flow))
+
+        async def mock_stop_flow(workspace, id, flow):
+            stop_calls.append((workspace, id, flow))
 
         config_receiver.start_flow = mock_start_flow
         config_receiver.stop_flow = mock_stop_flow
 
-        await config_receiver.fetch_and_apply()
+        await config_receiver.fetch_and_apply_workspace("default")
 
-        assert "flow1" not in config_receiver.flows
-        assert "flow2" in config_receiver.flows
-        assert "flow3" in config_receiver.flows
+        ws_flows = config_receiver.flows["default"]
+        assert "flow1" not in ws_flows
+        assert "flow2" in ws_flows
+        assert "flow3" in ws_flows
         assert len(start_calls) == 1
-        assert start_calls[0][0] == "flow3"
+        assert start_calls[0][:2] == ("default", "flow3")
         assert len(stop_calls) == 1
-        assert stop_calls[0][0] == "flow1"
+        assert stop_calls[0][:2] == ("default", "flow1")
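The workspace-scoped reconciliation these tests drive can be sketched as follows. The names (`flows`, `start_flow`, `stop_flow`, `config_version`) come from the tests themselves; the real `ConfigReceiver` also fetches the config over a request/response client and handles retries, which are omitted here:

```python
import asyncio
import json

# Sketch of per-workspace flow reconciliation, assuming the shape the
# tests exercise: flows is workspace -> {flow id -> flow config}.
class ConfigReceiverSketch:
    def __init__(self):
        self.flows = {}
        self.config_version = 0

    async def start_flow(self, workspace, id, flow):
        pass  # real class fans out to registered flow handlers

    async def stop_flow(self, workspace, id, flow):
        pass

    async def apply(self, workspace, version, flow_configs):
        """Diff desired config against current state for one workspace."""
        wanted = {k: json.loads(v) for k, v in flow_configs.items()}
        current = self.flows.get(workspace, {})
        for id in list(current):
            if id not in wanted:          # flows dropped from config stop
                await self.stop_flow(workspace, id, current[id])
        for id, flow in wanted.items():
            if id not in current:         # newly-configured flows start
                await self.start_flow(workspace, id, flow)
        self.flows[workspace] = wanted
        self.config_version = version
```

Flows present in both old and new config are left untouched, which is what the mixed-operations test above checks.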

View file

@@ -36,7 +36,6 @@ def _ge_response_dict():
         "metadata": {
             "id": "doc-1",
             "root": "",
-            "user": "alice",
             "collection": "testcoll",
         },
         "entities": [
@@ -59,7 +58,6 @@ def _triples_response_dict():
         "metadata": {
             "id": "doc-1",
             "root": "",
-            "user": "alice",
             "collection": "testcoll",
         },
         "triples": [
@@ -73,9 +71,9 @@ def _triples_response_dict():
     }
 
 
-def _make_request(id_="doc-1", user="alice"):
+def _make_request(id_="doc-1", workspace="alice"):
     request = Mock()
-    request.query = {"id": id_, "user": user}
+    request.query = {"id": id_, "workspace": workspace}
     return request
@@ -149,12 +147,8 @@ class TestCoreExportWireFormat:
         msg_type, payload = items[0]
         assert msg_type == "ge"
 
-        # Metadata envelope: only id/user/collection — no stale `m["m"]`.
-        assert payload["m"] == {
-            "i": "doc-1",
-            "u": "alice",
-            "c": "testcoll",
-        }
+        # Metadata envelope: only id/collection — no stale `m["m"]`.
+        assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
 
         # Entities: each carries the *singular* `v` and the term envelope
         assert len(payload["e"]) == 2
@@ -202,11 +196,7 @@ class TestCoreExportWireFormat:
         msg_type, payload = items[0]
         assert msg_type == "t"
 
-        assert payload["m"] == {
-            "i": "doc-1",
-            "u": "alice",
-            "c": "testcoll",
-        }
+        assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
 
         assert len(payload["t"]) == 1
@@ -240,7 +230,7 @@ class TestCoreImportWireFormat:
         payload = msgpack.packb((
             "ge",
             {
-                "m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
+                "m": {"i": "doc-1", "c": "testcoll"},
                 "e": [
                     {
                         "e": {"t": "i", "i": "http://example.org/alice"},
@@ -266,7 +256,7 @@ class TestCoreImportWireFormat:
         req = captured[0]
 
         assert req["operation"] == "put-kg-core"
-        assert req["user"] == "alice"
+        assert req["workspace"] == "alice"
         assert req["id"] == "doc-1"
 
         ge = req["graph-embeddings"]
@@ -275,7 +265,6 @@ class TestCoreImportWireFormat:
         assert "metadata" not in ge["metadata"]
         assert ge["metadata"] == {
             "id": "doc-1",
-            "user": "alice",
             "collection": "default",
         }
@@ -302,7 +291,7 @@ class TestCoreImportWireFormat:
         payload = msgpack.packb((
             "t",
             {
-                "m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
+                "m": {"i": "doc-1", "c": "testcoll"},
                 "t": [
                     {
                         "s": {"t": "i", "i": "http://example.org/alice"},
@@ -407,10 +396,9 @@ class TestCoreImportExportRoundTrip:
         original = _ge_response_dict()["graph-embeddings"]
         ge = req["graph-embeddings"]
 
-        # The import side overrides id/user from the URL query (intentional),
+        # The import side overrides id from the URL query (intentional),
         # so we only round-trip the entity payload itself.
         assert ge["metadata"]["id"] == original["metadata"]["id"]
-        assert ge["metadata"]["user"] == original["metadata"]["user"]
         assert len(ge["entities"]) == len(original["entities"])
 
         for got, want in zip(ge["entities"], original["entities"]):
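The envelope assertions above reduce metadata to two short keys on the wire. A sketch of that mapping, assuming only the `i`/`c` form the tests show (the gateway additionally serialises the whole tuple with msgpack, which is left out here; `pack_metadata`/`unpack_metadata` are hypothetical helper names):

```python
# Export side: long-key metadata -> short-key wire envelope.
# The former "u" (user) key is gone; workspace now travels in the
# request query, not in content metadata.
def pack_metadata(metadata):
    return {"i": metadata.get("id", ""), "c": metadata.get("collection", "")}

# Import side: short-key wire envelope -> long-key metadata.
def unpack_metadata(m):
    return {"id": m.get("i", ""), "collection": m.get("c", "")}
```

Note the mapping is lossy by design: `root` and anything else not in the envelope is dropped or re-derived on import.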

View file

@@ -72,10 +72,10 @@ class TestDispatcherManager:
 
         flow_data = {"name": "test_flow", "steps": []}
-        await manager.start_flow("flow1", flow_data)
+        await manager.start_flow("default", "flow1", flow_data)
 
-        assert "flow1" in manager.flows
-        assert manager.flows["flow1"] == flow_data
+        assert ("default", "flow1") in manager.flows
+        assert manager.flows[("default", "flow1")] == flow_data
 
     @pytest.mark.asyncio
     async def test_stop_flow(self):
@@ -86,11 +86,11 @@ class TestDispatcherManager:
         # Pre-populate with a flow
         flow_data = {"name": "test_flow", "steps": []}
-        manager.flows["flow1"] = flow_data
+        manager.flows[("default", "flow1")] = flow_data
 
-        await manager.stop_flow("flow1", flow_data)
+        await manager.stop_flow("default", "flow1", flow_data)
 
-        assert "flow1" not in manager.flows
+        assert ("default", "flow1") not in manager.flows
 
     def test_dispatch_global_service_returns_wrapper(self):
         """Test dispatch_global_service returns DispatcherWrapper"""
@@ -275,7 +275,7 @@ class TestDispatcherManager:
         manager = DispatcherManager(mock_backend, mock_config_receiver)
 
         # Setup test flow
-        manager.flows["test_flow"] = {
+        manager.flows[("default", "test_flow")] = {
             "interfaces": {
                 "triples-store": {"flow": "test_queue"}
             }
@@ -326,7 +326,7 @@ class TestDispatcherManager:
         manager = DispatcherManager(mock_backend, mock_config_receiver)
 
         # Setup test flow
-        manager.flows["test_flow"] = {
+        manager.flows[("default", "test_flow")] = {
             "interfaces": {
                 "triples-store": {"flow": "test_queue"}
             }
@@ -348,7 +348,7 @@ class TestDispatcherManager:
         manager = DispatcherManager(mock_backend, mock_config_receiver)
 
         # Setup test flow
-        manager.flows["test_flow"] = {
+        manager.flows[("default", "test_flow")] = {
             "interfaces": {
                 "triples-store": {"flow": "test_queue"}
             }
@@ -404,7 +404,7 @@ class TestDispatcherManager:
         params = {"flow": "test_flow", "kind": "agent"}
         result = await manager.process_flow_service("data", "responder", params)
 
-        manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent")
+        manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent")
         assert result == "flow_result"
 
     @pytest.mark.asyncio
@@ -415,14 +415,14 @@ class TestDispatcherManager:
         manager = DispatcherManager(mock_backend, mock_config_receiver)
 
         # Add flow to the flows dictionary
-        manager.flows["test_flow"] = {"services": {"agent": {}}}
+        manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
 
         # Pre-populate with existing dispatcher
         mock_dispatcher = Mock()
         mock_dispatcher.process = AsyncMock(return_value="cached_result")
-        manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
+        manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher
 
-        result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
+        result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
 
         mock_dispatcher.process.assert_called_once_with("data", "responder")
         assert result == "cached_result"
@@ -435,7 +435,7 @@ class TestDispatcherManager:
         manager = DispatcherManager(mock_backend, mock_config_receiver)
 
         # Setup test flow
-        manager.flows["test_flow"] = {
+        manager.flows[("default", "test_flow")] = {
             "interfaces": {
                 "agent": {
                     "request": "agent_request_queue",
@@ -453,7 +453,7 @@ class TestDispatcherManager:
         mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
         mock_dispatchers.__contains__.return_value = True
 
-        result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
+        result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
 
         # Verify dispatcher was created with correct parameters
         mock_dispatcher_class.assert_called_once_with(
@@ -461,14 +461,14 @@ class TestDispatcherManager:
             request_queue="agent_request_queue",
             response_queue="agent_response_queue",
             timeout=120,
-            consumer="api-gateway-test_flow-agent-request",
-            subscriber="api-gateway-test_flow-agent-request"
+            consumer="api-gateway-default-test_flow-agent-request",
+            subscriber="api-gateway-default-test_flow-agent-request"
         )
         mock_dispatcher.start.assert_called_once()
         mock_dispatcher.process.assert_called_once_with("data", "responder")
 
         # Verify dispatcher was cached
-        assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
+        assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher
         assert result == "new_result"
 
     @pytest.mark.asyncio
@pytest.mark.asyncio @pytest.mark.asyncio
@ -479,7 +479,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver) manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow # Setup test flow
manager.flows["test_flow"] = { manager.flows[("default", "test_flow")] = {
"interfaces": { "interfaces": {
"text-load": {"flow": "text_load_queue"} "text-load": {"flow": "text_load_queue"}
} }
@ -497,7 +497,7 @@ class TestDispatcherManager:
mock_dispatcher_class.return_value = mock_dispatcher mock_dispatcher_class.return_value = mock_dispatcher
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load") result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load")
# Verify dispatcher was created with correct parameters # Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with( mock_dispatcher_class.assert_called_once_with(
@ -508,7 +508,7 @@ class TestDispatcherManager:
mock_dispatcher.process.assert_called_once_with("data", "responder") mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached # Verify dispatcher was cached
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher
assert result == "sender_result" assert result == "sender_result"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -519,7 +519,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver) manager = DispatcherManager(mock_backend, mock_config_receiver)
with pytest.raises(RuntimeError, match="Invalid flow"): with pytest.raises(RuntimeError, match="Invalid flow"):
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent") await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoke_flow_service_unsupported_kind_by_flow(self): async def test_invoke_flow_service_unsupported_kind_by_flow(self):
@ -529,14 +529,14 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver) manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow without agent interface # Setup test flow without agent interface
manager.flows["test_flow"] = { manager.flows[("default", "test_flow")] = {
"interfaces": { "interfaces": {
"text-completion": {"request": "req", "response": "resp"} "text-completion": {"request": "req", "response": "resp"}
} }
} }
with pytest.raises(RuntimeError, match="This kind not supported by flow"): with pytest.raises(RuntimeError, match="This kind not supported by flow"):
await manager.invoke_flow_service("data", "responder", "test_flow", "agent") await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoke_flow_service_invalid_kind(self): async def test_invoke_flow_service_invalid_kind(self):
@ -546,7 +546,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver) manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow with interface but unsupported kind # Setup test flow with interface but unsupported kind
manager.flows["test_flow"] = { manager.flows[("default", "test_flow")] = {
"interfaces": { "interfaces": {
"invalid-kind": {"request": "req", "response": "resp"} "invalid-kind": {"request": "req", "response": "resp"}
} }
@ -558,7 +558,7 @@ class TestDispatcherManager:
mock_sender_dispatchers.__contains__.return_value = False mock_sender_dispatchers.__contains__.return_value = False
with pytest.raises(RuntimeError, match="Invalid kind"): with pytest.raises(RuntimeError, match="Invalid kind"):
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind") await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self): async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
@ -608,7 +608,7 @@ class TestDispatcherManager:
mock_config_receiver = Mock() mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver) manager = DispatcherManager(mock_backend, mock_config_receiver)
manager.flows["test_flow"] = { manager.flows[("default", "test_flow")] = {
"interfaces": { "interfaces": {
"agent": { "agent": {
"request": "agent_request_queue", "request": "agent_request_queue",
@ -630,7 +630,7 @@ class TestDispatcherManager:
mock_rr_dispatchers.__contains__.return_value = True mock_rr_dispatchers.__contains__.return_value = True
results = await asyncio.gather(*[ results = await asyncio.gather(*[
manager.invoke_flow_service("data", "responder", "test_flow", "agent") manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
for _ in range(5) for _ in range(5)
]) ])
@ -638,5 +638,5 @@ class TestDispatcherManager:
"Dispatcher class instantiated more than once — duplicate consumer bug" "Dispatcher class instantiated more than once — duplicate consumer bug"
) )
assert mock_dispatcher.start.call_count == 1 assert mock_dispatcher.start.call_count == 1
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher
assert all(r == "result" for r in results) assert all(r == "result" for r in results)

View file

@@ -186,7 +186,6 @@ class TestEntityContextsImportMessageProcessing:
         assert isinstance(sent, EntityContexts)
         assert isinstance(sent.metadata, Metadata)
         assert sent.metadata.id == "doc-123"
-        assert sent.metadata.user == "testuser"
         assert sent.metadata.collection == "testcollection"
         assert len(sent.entities) == 2

View file

@@ -188,7 +188,6 @@ class TestGraphEmbeddingsImportMessageProcessing:
         assert isinstance(sent, GraphEmbeddings)
         assert isinstance(sent.metadata, Metadata)
         assert sent.metadata.id == "doc-123"
-        assert sent.metadata.user == "testuser"
         assert sent.metadata.collection == "testcollection"
         assert len(sent.entities) == 2

Some files were not shown because too many files have changed in this diff.