feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840)

Introduces `workspace` as the isolation boundary for config, flows,
library, and knowledge data. Removes `user` as a schema-level field
throughout the code, API specs, and tests; workspace provides the
same separation more cleanly at the trusted flow.workspace layer
rather than through client-supplied message fields.

Design
------
- IAM tech spec (docs/tech-specs/iam.md) documents current state,
  proposed auth/access model, and migration direction.
- Data ownership model (docs/tech-specs/data-ownership-model.md)
  captures the workspace/collection/flow hierarchy.

Schema + messaging
------------------
- Drop `user` field from AgentRequest/Step, GraphRagQuery,
  DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest,
  Sparql/Rows/Structured QueryRequest, ToolServiceRequest.
- Keep collection/workspace routing via flow.workspace at the
  service layer.
- Translators updated to not serialise/deserialise user.

API specs
---------
- OpenAPI schemas and path examples cleaned of user fields.
- Websocket async-api messages updated.
- Removed the unused parameters/User.yaml.

Services + base
---------------
- Librarian, collection manager, knowledge, config: all operations
  scoped by workspace. Config client API takes workspace as first
  positional arg.
- `flow.workspace` set at flow start time by the infrastructure;
  no longer pass-through from clients.
- Tool service drops user-personalisation passthrough.

CLI + SDK
---------
- tg-init-workspace and workspace-aware import/export.
- All tg-* commands drop user args; accept --workspace.
- Python API/SDK (flow, socket_client, async_*, explainability,
  library) drop user kwargs from every method signature.

MCP server
----------
- All tool endpoints drop user parameters; socket_manager no longer
  keyed per user.

Flow service
------------
- Closure-based topic cleanup on flow stop: only delete topics
  whose blueprint template was parameterised AND no remaining
  live flow (across all workspaces) still resolves to that topic.
  Three scopes fall out naturally from template analysis:
    * {id} -> per-flow, deleted on stop
    * {blueprint} -> per-blueprint, kept while any flow of the
      same blueprint exists
    * {workspace} -> per-workspace, kept while any flow in the
      workspace exists
    * literal -> global, never deleted (e.g. tg.request.librarian)
  Fixes a bug where stopping a flow silently destroyed the global
  librarian exchange, wedging all library operations until manual
  restart.

RabbitMQ backend
----------------
- heartbeat=60, blocked_connection_timeout=300. Catches silently
  dead connections (broker restart, orphaned channels, network
  partitions) within ~2 heartbeat windows, so the consumer
  reconnects and re-binds its queue rather than sitting forever
  on a zombie connection.

Tests
-----
- Full test refresh: unit, integration, contract, provenance.
- Dropped user-field assertions and constructor kwargs across
  ~100 test files.
- Renamed user-collection isolation tests to workspace-collection.
This commit is contained in:
cybermaggedon 2026-04-21 23:23:01 +01:00 committed by GitHub
parent 9332089b3d
commit d35473f7f7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
377 changed files with 6868 additions and 5785 deletions

View file

@ -125,21 +125,39 @@ class AsyncProcessor:
response_metrics = config_resp_metrics,
)
async def fetch_config(self):
"""Fetch full config from config service using a short-lived
request/response client. Returns (config, version) or raises."""
client = self._create_config_client()
try:
await client.start()
resp = await client.request(
ConfigRequest(operation="config"),
timeout=10,
)
if resp.error:
raise RuntimeError(f"Config error: {resp.error.message}")
return resp.config, resp.version
finally:
await client.stop()
async def _fetch_type_workspace(self, client, workspace, config_type):
"""Fetch config values of a single type within one workspace.
Returns dict of {key: value}."""
resp = await client.request(
ConfigRequest(
operation="getvalues",
workspace=workspace,
type=config_type,
),
timeout=10,
)
if resp.error:
raise RuntimeError(f"Config error: {resp.error.message}")
return {v.key: v.value for v in resp.values}
async def _fetch_type_all_workspaces(self, client, config_type):
"""Fetch config values of a single type across all workspaces.
Returns dict of {workspace: {key: value}}."""
resp = await client.request(
ConfigRequest(
operation="getvalues-all-ws",
type=config_type,
),
timeout=10,
)
if resp.error:
raise RuntimeError(f"Config error: {resp.error.message}")
grouped = {}
for v in resp.values:
ws = grouped.setdefault(v.workspace, {})
ws[v.key] = v.value
return grouped, resp.version
# This is called to start dynamic behaviour.
# Implements the subscribe-then-fetch pattern to avoid race conditions.
@ -155,21 +173,51 @@ class AsyncProcessor:
# processed by on_config_notify, which does the version check
async def fetch_and_apply_config(self):
"""Fetch full config from config service and apply to all handlers.
Retries until successful config service may not be ready yet."""
"""Startup: for each registered handler, fetch config for all its
types across all workspaces and invoke the handler once per
workspace. Retries until successful config service may not be
ready yet."""
while self.running:
try:
config, version = await self.fetch_config()
client = self._create_config_client()
try:
await client.start()
logger.info(f"Fetched config version {version}")
version = 0
self.config_version = version
for entry in self.config_handlers:
handler_types = entry["types"]
# Apply to all handlers (startup = invoke all)
for entry in self.config_handlers:
await entry["handler"](config, version)
# Handlers registered without types get nothing
# at startup (there is no "all types" fetch).
if not handler_types:
continue
# Group all registered types by workspace:
# {workspace: {type: {key: value}}}
per_ws = {}
for t in handler_types:
type_data, v = \
await self._fetch_type_all_workspaces(
client, t,
)
version = max(version, v)
for ws, kv in type_data.items():
per_ws.setdefault(ws, {})[t] = kv
# Call the handler once per workspace
for ws, config in per_ws.items():
await entry["handler"](ws, config, version)
logger.info(
f"Applied startup config version {version}"
)
self.config_version = version
finally:
await client.stop()
return
@ -204,8 +252,9 @@ class AsyncProcessor:
# Called when a config notify message arrives
async def on_config_notify(self, message, consumer, flow):
notify_version = message.value().version
notify_types = set(message.value().types)
v = message.value()
notify_version = v.version
changes = v.changes # dict of type -> [workspaces]
# Skip if we already have this version or newer
if notify_version <= self.config_version:
@ -215,41 +264,60 @@ class AsyncProcessor:
)
return
# Check if any handler cares about the affected types
if notify_types:
any_interested = False
for entry in self.config_handlers:
handler_types = entry["types"]
if handler_types is None or notify_types & handler_types:
any_interested = True
break
notify_types = set(changes.keys())
if not any_interested:
logger.debug(
f"Ignoring config notify v{notify_version}, "
f"no handlers for types {notify_types}"
)
self.config_version = notify_version
return
# Filter out handlers that don't care about any of the changed
# types. A handler registered without types never fires on
# notifications (nothing to scope to).
interested = []
for entry in self.config_handlers:
handler_types = entry["types"]
if handler_types and notify_types & handler_types:
interested.append(entry)
if not interested:
logger.debug(
f"Ignoring config notify v{notify_version}, "
f"no handlers for types {notify_types}"
)
self.config_version = notify_version
return
logger.info(
f"Config notify v{notify_version} types={list(notify_types)}, "
f"fetching config..."
f"Config notify v{notify_version} "
f"types={list(notify_types)}, fetching config..."
)
# Fetch full config using short-lived client
try:
config, version = await self.fetch_config()
client = self._create_config_client()
try:
await client.start()
self.config_version = version
for entry in interested:
handler_types = entry["types"]
# Invoke handlers that care about the affected types
for entry in self.config_handlers:
handler_types = entry["types"]
if handler_types is None:
await entry["handler"](config, version)
elif not notify_types or notify_types & handler_types:
await entry["handler"](config, version)
# Build {workspace: {type: {key: value}}} for types
# this handler cares about, where the workspace was
# affected for that type.
per_ws = {}
for t in handler_types:
if t not in changes:
continue
for ws in changes[t]:
kv = await self._fetch_type_workspace(
client, ws, t,
)
per_ws.setdefault(ws, {})[t] = kv
for ws, config in per_ws.items():
await entry["handler"](
ws, config, notify_version,
)
finally:
await client.stop()
self.config_version = notify_version
except Exception as e:
logger.error(

View file

@ -48,12 +48,13 @@ class ChunkingService(FlowProcessor):
await super(ChunkingService, self).start()
await self.librarian.start()
async def get_document_text(self, doc):
async def get_document_text(self, doc, workspace):
"""
Get text content from a TextDocument, fetching from librarian if needed.
Args:
doc: TextDocument with either inline text or document_id
workspace: Workspace for librarian lookup (from flow.workspace)
Returns:
str: The document text content
@ -62,7 +63,7 @@ class ChunkingService(FlowProcessor):
logger.info(f"Fetching document {doc.document_id} from librarian...")
text = await self.librarian.fetch_document_text(
document_id=doc.document_id,
user=doc.metadata.user,
workspace=workspace,
)
logger.info(f"Fetched {len(text)} characters from librarian")
return text

View file

@ -15,114 +15,139 @@ class CollectionConfigHandler:
Storage services should:
1. Inherit from this class along with their service base class
2. Call register_config_handler(self.on_collection_config) in __init__
3. Implement create_collection(user, collection, metadata) method
4. Implement delete_collection(user, collection) method
3. Implement create_collection(workspace, collection, metadata) method
4. Implement delete_collection(workspace, collection) method
"""
def __init__(self, **kwargs):
# Track known collections: {(user, collection): metadata_dict}
# Track known collections: {(workspace, collection): metadata_dict}
self.known_collections: Dict[tuple, dict] = {}
# Pass remaining kwargs up the inheritance chain
super().__init__(**kwargs)
async def on_collection_config(self, config: dict, version: int):
async def on_collection_config(
self, workspace: str, config: dict, version: int
):
"""
Handle config push messages and extract collection information
for a single workspace.
Args:
workspace: Workspace the config applies to
config: Configuration dictionary from ConfigPush message
version: Configuration version number
"""
logger.info(f"Processing collection configuration (version {version})")
logger.info(
f"Processing collection configuration "
f"(version {version}, workspace {workspace})"
)
# Extract collections from config (treat missing key as empty)
# Extract collections from config (treat missing key as empty).
# Each config key IS the collection name — config is already
# partitioned by workspace, so no workspace prefix is needed
# on the key.
collection_config = config.get("collection", {})
# Track which collections we've seen in this config
current_collections: Set[tuple] = set()
# Process each collection in the config
for key, value_json in collection_config.items():
for collection, value_json in collection_config.items():
try:
# Parse user:collection key
if ":" not in key:
logger.warning(f"Invalid collection key format (expected user:collection): {key}")
continue
current_collections.add((workspace, collection))
user, collection = key.split(":", 1)
current_collections.add((user, collection))
# Parse metadata
metadata = json.loads(value_json)
# Check if this is a new collection or updated
collection_key = (user, collection)
if collection_key not in self.known_collections:
logger.info(f"New collection detected: {user}/{collection}")
await self.create_collection(user, collection, metadata)
self.known_collections[collection_key] = metadata
key = (workspace, collection)
if key not in self.known_collections:
logger.info(
f"New collection detected: {workspace}/{collection}"
)
await self.create_collection(
workspace, collection, metadata
)
self.known_collections[key] = metadata
else:
# Collection already exists, update metadata if changed
if self.known_collections[collection_key] != metadata:
logger.info(f"Collection metadata updated: {user}/{collection}")
# Most storage services don't need to do anything for metadata updates
# They just need to know the collection exists
self.known_collections[collection_key] = metadata
if self.known_collections[key] != metadata:
logger.info(
f"Collection metadata updated: "
f"{workspace}/{collection}"
)
self.known_collections[key] = metadata
except Exception as e:
logger.error(f"Error processing collection config for key {key}: {e}", exc_info=True)
logger.error(
f"Error processing collection config for "
f"{workspace}/{collection}: {e}",
exc_info=True,
)
# Find collections that were deleted (in known but not in current)
deleted_collections = set(self.known_collections.keys()) - current_collections
for user, collection in deleted_collections:
logger.info(f"Collection deleted: {user}/{collection}")
# Find collections for THIS workspace that were deleted (in
# known but not in current). Only compare collections owned by
# this workspace — other workspaces' collections are not
# affected by this config update.
known_for_ws = {
(w, c) for (w, c) in self.known_collections.keys()
if w == workspace
}
deleted_collections = known_for_ws - current_collections
for ws, collection in deleted_collections:
logger.info(f"Collection deleted: {ws}/{collection}")
try:
# Remove from known_collections FIRST to immediately reject new writes
# This eliminates race condition with worker threads
del self.known_collections[(user, collection)]
# Physical deletion happens after - worker threads already rejecting writes
await self.delete_collection(user, collection)
# Remove from known_collections FIRST to immediately
# reject new writes
del self.known_collections[(ws, collection)]
await self.delete_collection(ws, collection)
except Exception as e:
logger.error(f"Error deleting collection {user}/{collection}: {e}", exc_info=True)
# If physical deletion failed, should we re-add to known_collections?
# For now, keep it removed - collection is logically deleted per config
logger.error(
f"Error deleting collection {ws}/{collection}: {e}",
exc_info=True,
)
logger.debug(f"Collection config processing complete. Known collections: {len(self.known_collections)}")
logger.debug(
f"Collection config processing complete. "
f"Known collections: {len(self.known_collections)}"
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(
self, workspace: str, collection: str, metadata: dict,
):
"""
Create a collection in the storage backend.
Subclasses must implement this method.
Args:
user: User ID
workspace: Workspace ID
collection: Collection ID
metadata: Collection metadata dictionary
"""
raise NotImplementedError("Storage service must implement create_collection method")
raise NotImplementedError(
"Storage service must implement create_collection method"
)
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""
Delete a collection from the storage backend.
Subclasses must implement this method.
Args:
user: User ID
workspace: Workspace ID
collection: Collection ID
"""
raise NotImplementedError("Storage service must implement delete_collection method")
raise NotImplementedError(
"Storage service must implement delete_collection method"
)
def collection_exists(self, user: str, collection: str) -> bool:
def collection_exists(self, workspace: str, collection: str) -> bool:
"""
Check if a collection is known to exist
Check if a collection is known to exist.
Args:
user: User ID
workspace: Workspace ID
collection: Collection ID
Returns:
True if collection exists, False otherwise
"""
return (user, collection) in self.known_collections
return (workspace, collection) in self.known_collections

View file

@ -18,10 +18,11 @@ class ConfigClient(RequestResponse):
)
return resp
async def get(self, type, key, timeout=CONFIG_TIMEOUT):
async def get(self, workspace, type, key, timeout=CONFIG_TIMEOUT):
"""Get a single config value. Returns the value string or None."""
resp = await self._request(
operation="get",
workspace=workspace,
keys=[ConfigKey(type=type, key=key)],
timeout=timeout,
)
@ -29,19 +30,21 @@ class ConfigClient(RequestResponse):
return resp.values[0].value
return None
async def put(self, type, key, value, timeout=CONFIG_TIMEOUT):
async def put(self, workspace, type, key, value, timeout=CONFIG_TIMEOUT):
"""Put a single config value."""
await self._request(
operation="put",
workspace=workspace,
values=[ConfigValue(type=type, key=key, value=value)],
timeout=timeout,
)
async def put_many(self, values, timeout=CONFIG_TIMEOUT):
"""Put multiple config values in a single request.
values is a list of (type, key, value) tuples."""
async def put_many(self, workspace, values, timeout=CONFIG_TIMEOUT):
"""Put multiple config values in a single request within a
single workspace. values is a list of (type, key, value) tuples."""
await self._request(
operation="put",
workspace=workspace,
values=[
ConfigValue(type=t, key=k, value=v)
for t, k, v in values
@ -49,19 +52,21 @@ class ConfigClient(RequestResponse):
timeout=timeout,
)
async def delete(self, type, key, timeout=CONFIG_TIMEOUT):
async def delete(self, workspace, type, key, timeout=CONFIG_TIMEOUT):
"""Delete a single config key."""
await self._request(
operation="delete",
workspace=workspace,
keys=[ConfigKey(type=type, key=key)],
timeout=timeout,
)
async def delete_many(self, keys, timeout=CONFIG_TIMEOUT):
"""Delete multiple config keys in a single request.
keys is a list of (type, key) tuples."""
async def delete_many(self, workspace, keys, timeout=CONFIG_TIMEOUT):
"""Delete multiple config keys in a single request within a
single workspace. keys is a list of (type, key) tuples."""
await self._request(
operation="delete",
workspace=workspace,
keys=[
ConfigKey(type=t, key=k)
for t, k in keys
@ -69,15 +74,26 @@ class ConfigClient(RequestResponse):
timeout=timeout,
)
async def keys(self, type, timeout=CONFIG_TIMEOUT):
"""List all keys for a config type."""
async def keys(self, workspace, type, timeout=CONFIG_TIMEOUT):
"""List all keys for a config type within a workspace."""
resp = await self._request(
operation="list",
workspace=workspace,
type=type,
timeout=timeout,
)
return resp.directory
async def workspaces_for_type(self, type, timeout=CONFIG_TIMEOUT):
"""Return the set of distinct workspaces with any config of
the given type."""
resp = await self._request(
operation="getvalues-all-ws",
type=type,
timeout=timeout,
)
return {v.workspace for v in resp.values if v.workspace}
class ConfigClientSpec(RequestResponseSpec):
def __init__(

View file

@ -24,7 +24,10 @@ class ConsumerSpec(Spec):
flow = flow,
backend = processor.pubsub,
topic = definition["topics"][self.name],
subscriber = processor.id + "--" + flow.name + "--" + self.name,
subscriber = (
processor.id + "--" + flow.workspace + "--" +
flow.name + "--" + self.name
),
schema = self.schema,
handler = self.handler,
metrics = consumer_metrics,

View file

@ -9,14 +9,12 @@ from .. knowledge import Uri, Literal
logger = logging.getLogger(__name__)
class DocumentEmbeddingsClient(RequestResponse):
async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30):
async def query(self, vector, limit=20, collection="default", timeout=30):
resp = await self.request(
DocumentEmbeddingsRequest(
vector = vector,
limit = limit,
user = user,
collection = collection
),
timeout=timeout

View file

@ -60,7 +60,9 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
logger.debug(f"Handling document embeddings query request {id}...")
docs = await self.query_document_embeddings(request)
docs = await self.query_document_embeddings(
flow.workspace, request,
)
logger.debug("Sending document embeddings query response...")
r = DocumentEmbeddingsResponse(chunks=docs, error=None)

View file

@ -41,7 +41,8 @@ class DocumentEmbeddingsStoreService(FlowProcessor):
request = msg.value()
await self.store_document_embeddings(request)
# Workspace comes from the flow the message arrived on.
await self.store_document_embeddings(flow.workspace, request)
except TooManyRequests as e:
raise e

View file

@ -2,7 +2,7 @@
Base class for dynamically pluggable tool services.
Tool services are Pulsar services that can be invoked as agent tools.
They receive a ToolServiceRequest with user, config, and arguments,
They receive a ToolServiceRequest with config and arguments,
and return a ToolServiceResponse with the result.
Uses direct Pulsar topics (no flow configuration required):
@ -42,7 +42,6 @@ class DynamicToolService(AsyncProcessor):
the tool's logic.
The invoke method receives:
- user: The user context for multi-tenancy
- config: Dict of config values from the tool descriptor
- arguments: Dict of arguments from the LLM
@ -115,14 +114,13 @@ class DynamicToolService(AsyncProcessor):
id = msg.properties().get("id", "unknown")
# Parse the request
user = request.user or "trustgraph"
config = json.loads(request.config) if request.config else {}
arguments = json.loads(request.arguments) if request.arguments else {}
logger.debug(f"Tool service request: user={user}, config={config}, arguments={arguments}")
logger.debug(f"Tool service request: config={config}, arguments={arguments}")
# Invoke the tool implementation
response = await self.invoke(user, config, arguments)
response = await self.invoke(config, arguments)
# Send success response
await self.producer.send(
@ -159,14 +157,13 @@ class DynamicToolService(AsyncProcessor):
properties={"id": id if id else "unknown"}
)
async def invoke(self, user, config, arguments):
async def invoke(self, config, arguments):
"""
Invoke the tool service.
Override this method in subclasses to implement the tool's logic.
Args:
user: The user context for multi-tenancy
config: Dict of config values from the tool descriptor
arguments: Dict of arguments from the LLM

View file

@ -4,15 +4,16 @@ import asyncio
class Flow:
"""
Runtime representation of a deployed flow process.
This class maintains internal processor states and orchestrates
lifecycles (start, stop) for inputs (consumers) and parameters
lifecycles (start, stop) for inputs (consumers) and parameters
that drive data flowing across linked nodes.
"""
def __init__(self, id, flow, processor, defn):
def __init__(self, id, flow, workspace, processor, defn):
self.id = id
self.name = flow
self.workspace = workspace
self.producer = {}

View file

@ -35,6 +35,8 @@ class FlowProcessor(AsyncProcessor):
)
# Initialise flow information state
# Keyed by (workspace, flow) tuples; each workspace has its own
# set of flow variants for this processor.
self.flows = {}
# These can be overriden by a derived class:
@ -48,23 +50,28 @@ class FlowProcessor(AsyncProcessor):
def register_specification(self, spec: Any) -> None:
self.specifications.append(spec)
# Start processing for a new flow
async def start_flow(self, flow, defn):
self.flows[flow] = Flow(self.id, flow, self, defn)
await self.flows[flow].start()
logger.info(f"Started flow: {flow}")
# Stop processing for a new flow
async def stop_flow(self, flow):
if flow in self.flows:
await self.flows[flow].stop()
del self.flows[flow]
logger.info(f"Stopped flow: {flow}")
# Start processing for a new flow within a workspace
async def start_flow(self, workspace, flow, defn):
key = (workspace, flow)
self.flows[key] = Flow(self.id, flow, workspace, self, defn)
await self.flows[key].start()
logger.info(f"Started flow: {workspace}/{flow}")
# Event handler - called for a configuration change
async def on_configure_flows(self, config, version):
# Stop processing for a flow within a workspace
async def stop_flow(self, workspace, flow):
key = (workspace, flow)
if key in self.flows:
await self.flows[key].stop()
del self.flows[key]
logger.info(f"Stopped flow: {workspace}/{flow}")
logger.info(f"Got config version {version}")
# Event handler - called for a configuration change for a single
# workspace
async def on_configure_flows(self, workspace, config, version):
logger.info(
f"Got config version {version} for workspace {workspace}"
)
config_type = f"processor:{self.id}"
@ -76,26 +83,28 @@ class FlowProcessor(AsyncProcessor):
for k, v in config[config_type].items()
}
else:
logger.debug("No configuration settings for me.")
logger.debug(
f"No configuration settings for me in {workspace}."
)
flow_config = {}
# Get list of flows which should be running and are currently
# running
wanted_flows = flow_config.keys()
# This takes a copy, needed because dict gets modified by stop_flow
current_flows = list(self.flows.keys())
# Get list of flows which should be running in this workspace,
# and the list currently running in this workspace
wanted_flows = set(flow_config.keys())
current_flows = {
f for (ws, f) in self.flows.keys() if ws == workspace
}
# Start all the flows which arent currently running
for flow in wanted_flows:
if flow not in current_flows:
await self.start_flow(flow, flow_config[flow])
# Start all the flows which aren't currently running in this
# workspace
for flow in wanted_flows - current_flows:
await self.start_flow(workspace, flow, flow_config[flow])
# Stop all the unwanted flows which are due to be stopped
for flow in current_flows:
if flow not in wanted_flows:
await self.stop_flow(flow)
# Stop all the unwanted flows in this workspace
for flow in current_flows - wanted_flows:
await self.stop_flow(workspace, flow)
logger.info("Handled config update")
logger.info(f"Handled config update for workspace {workspace}")
# Start threads, just call parent
async def start(self):

View file

@ -22,14 +22,12 @@ def to_value(x: Any) -> Any:
return Literal(x.value or x.iri)
class GraphEmbeddingsClient(RequestResponse):
async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30):
async def query(self, vector, limit=20, collection="default", timeout=30):
resp = await self.request(
GraphEmbeddingsRequest(
vector = vector,
limit = limit,
user = user,
collection = collection
),
timeout=timeout

View file

@ -60,7 +60,9 @@ class GraphEmbeddingsQueryService(FlowProcessor):
logger.debug(f"Handling graph embeddings query request {id}...")
entities = await self.query_graph_embeddings(request)
entities = await self.query_graph_embeddings(
flow.workspace, request,
)
logger.debug("Sending graph embeddings query response...")
r = GraphEmbeddingsResponse(entities=entities, error=None)

View file

@ -41,7 +41,8 @@ class GraphEmbeddingsStoreService(FlowProcessor):
request = msg.value()
await self.store_graph_embeddings(request)
# Workspace comes from the flow the message arrived on.
await self.store_graph_embeddings(flow.workspace, request)
except TooManyRequests as e:
raise e

View file

@ -3,7 +3,7 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import GraphRagQuery, GraphRagResponse
class GraphRagClient(RequestResponse):
async def rag(self, query, user="trustgraph", collection="default",
async def rag(self, query, collection="default",
chunk_callback=None, explain_callback=None,
parent_uri="",
timeout=600):
@ -12,7 +12,6 @@ class GraphRagClient(RequestResponse):
Args:
query: The question to ask
user: User identifier
collection: Collection identifier
chunk_callback: Optional async callback(text, end_of_stream) for text chunks
explain_callback: Optional async callback(explain_id, explain_graph, explain_triples) for explain notifications
@ -49,7 +48,6 @@ class GraphRagClient(RequestResponse):
await self.request(
GraphRagQuery(
query = query,
user = user,
collection = collection,
parent_uri = parent_uri,
),

View file

@ -10,7 +10,7 @@ Usage:
id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params
)
await self.librarian.start()
content = await self.librarian.fetch_document_content(doc_id, user)
content = await self.librarian.fetch_document_content(doc_id, workspace)
"""
import asyncio
@ -150,7 +150,7 @@ class LibrarianClient:
finally:
self._streams.pop(request_id, None)
async def fetch_document_content(self, document_id, user, timeout=120):
async def fetch_document_content(self, document_id, workspace, timeout=120):
"""Fetch document content using streaming.
Returns base64-encoded content. Caller is responsible for decoding.
@ -158,7 +158,7 @@ class LibrarianClient:
req = LibrarianRequest(
operation="stream-document",
document_id=document_id,
user=user,
workspace=workspace,
)
chunks = await self.stream(req, timeout=timeout)
@ -176,24 +176,24 @@ class LibrarianClient:
return base64.b64encode(raw)
async def fetch_document_text(self, document_id, user, timeout=120):
async def fetch_document_text(self, document_id, workspace, timeout=120):
"""Fetch document content and decode as UTF-8 text."""
content = await self.fetch_document_content(
document_id, user, timeout=timeout,
document_id, workspace, timeout=timeout,
)
return base64.b64decode(content).decode("utf-8")
async def fetch_document_metadata(self, document_id, user, timeout=120):
async def fetch_document_metadata(self, document_id, workspace, timeout=120):
"""Fetch document metadata from the librarian."""
req = LibrarianRequest(
operation="get-document-metadata",
document_id=document_id,
user=user,
workspace=workspace,
)
response = await self.request(req, timeout=timeout)
return response.document_metadata
async def save_child_document(self, doc_id, parent_id, user, content,
async def save_child_document(self, doc_id, parent_id, workspace, content,
document_type="chunk", title=None,
kind="text/plain", timeout=120):
"""Save a child document to the librarian."""
@ -202,7 +202,7 @@ class LibrarianClient:
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind=kind,
title=title or doc_id,
parent_id=parent_id,
@ -218,7 +218,7 @@ class LibrarianClient:
await self.request(req, timeout=timeout)
return doc_id
async def save_document(self, doc_id, user, content, title=None,
async def save_document(self, doc_id, workspace, content, title=None,
document_type="answer", kind="text/plain",
timeout=120):
"""Save a document to the librarian."""
@ -227,7 +227,7 @@ class LibrarianClient:
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind=kind,
title=title or doc_id,
document_type=document_type,
@ -238,7 +238,7 @@ class LibrarianClient:
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
user=user,
workspace=workspace,
)
await self.request(req, timeout=timeout)

View file

@ -133,8 +133,9 @@ class RequestResponseSpec(Spec):
# Make subscription names unique, so that all subscribers get
# to see all response messages
subscription = (
processor.id + "--" + flow.name + "--" + self.request_name +
"--" + str(uuid.uuid4())
processor.id + "--" + flow.workspace + "--" +
flow.name + "--" + self.request_name + "--" +
str(uuid.uuid4())
),
consumer_name = flow.id,
request_topic = definition["topics"][self.request_name],

View file

@ -3,13 +3,12 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
class RowEmbeddingsQueryClient(RequestResponse):
async def row_embeddings_query(
self, vector, schema_name, user="trustgraph", collection="default",
self, vector, schema_name, collection="default",
index_name=None, limit=10, timeout=600
):
request = RowEmbeddingsRequest(
vector=vector,
schema_name=schema_name,
user=user,
collection=collection,
limit=limit
)

View file

@ -2,11 +2,10 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import StructuredQueryRequest, StructuredQueryResponse
class StructuredQueryClient(RequestResponse):
async def structured_query(self, question, user="trustgraph", collection="default", timeout=600):
async def structured_query(self, question, collection="default", timeout=600):
resp = await self.request(
StructuredQueryRequest(
question = question,
user = user,
collection = collection
),
timeout=timeout

View file

@ -21,7 +21,7 @@ class SubscriberSpec(Spec):
subscriber = Subscriber(
backend = processor.pubsub,
topic = definition["topics"][self.name],
subscription = flow.id,
subscription = flow.id + "--" + flow.workspace + "--" + flow.name,
consumer_name = flow.id,
schema = self.schema,
metrics = subscriber_metrics,

View file

@ -64,6 +64,7 @@ class ToolService(FlowProcessor):
id = msg.properties()["id"]
response = await self.invoke_tool(
flow.workspace,
request.name,
json.loads(request.parameters) if request.parameters else {},
)

View file

@ -11,12 +11,11 @@ logger = logging.getLogger(__name__)
class ToolServiceClient(RequestResponse):
"""Client for invoking dynamically configured tool services."""
async def call(self, user, config, arguments, timeout=600):
async def call(self, config, arguments, timeout=600):
"""
Call a tool service.
Args:
user: User context for multi-tenancy
config: Dict of config values (e.g., {"collection": "customers"})
arguments: Dict of arguments from LLM
timeout: Request timeout in seconds
@ -26,7 +25,6 @@ class ToolServiceClient(RequestResponse):
"""
resp = await self.request(
ToolServiceRequest(
user=user,
config=json.dumps(config) if config else "{}",
arguments=json.dumps(arguments) if arguments else "{}",
),
@ -38,12 +36,11 @@ class ToolServiceClient(RequestResponse):
return resp.response
async def call_streaming(self, user, config, arguments, callback, timeout=600):
async def call_streaming(self, config, arguments, callback, timeout=600):
"""
Call a tool service with streaming response.
Args:
user: User context for multi-tenancy
config: Dict of config values
arguments: Dict of arguments from LLM
callback: Async function called with each response chunk
@ -66,7 +63,6 @@ class ToolServiceClient(RequestResponse):
await self.request(
ToolServiceRequest(
user=user,
config=json.dumps(config) if config else "{}",
arguments=json.dumps(arguments) if arguments else "{}",
),

View file

@ -45,7 +45,7 @@ def from_value(x: Any) -> Any:
class TriplesClient(RequestResponse):
async def query(self, s=None, p=None, o=None, limit=20,
user="trustgraph", collection="default",
collection="default",
timeout=30, g=None):
resp = await self.request(
@ -54,7 +54,6 @@ class TriplesClient(RequestResponse):
p = from_value(p),
o = from_value(o),
limit = limit,
user = user,
collection = collection,
g = g,
),
@ -72,7 +71,7 @@ class TriplesClient(RequestResponse):
return triples
async def query_stream(self, s=None, p=None, o=None, limit=20,
user="trustgraph", collection="default",
collection="default",
batch_size=20, timeout=30,
batch_callback=None, g=None):
"""
@ -81,7 +80,6 @@ class TriplesClient(RequestResponse):
Args:
s, p, o: Triple pattern (None for wildcard)
limit: Maximum total triples to return
user: User/keyspace
collection: Collection name
batch_size: Triples per batch
timeout: Request timeout in seconds
@ -116,7 +114,6 @@ class TriplesClient(RequestResponse):
p=from_value(p),
o=from_value(o),
limit=limit,
user=user,
collection=collection,
streaming=True,
batch_size=batch_size,

View file

@ -58,9 +58,13 @@ class TriplesQueryService(FlowProcessor):
logger.debug(f"Handling triples query request {id}...")
workspace = flow.workspace
if request.streaming:
# Streaming mode: send batches
async for batch, is_final in self.query_triples_stream(request):
async for batch, is_final in self.query_triples_stream(
workspace, request,
):
r = TriplesQueryResponse(
triples=batch,
error=None,
@ -70,7 +74,7 @@ class TriplesQueryService(FlowProcessor):
logger.debug("Triples query streaming completed")
else:
# Non-streaming mode: single response
triples = await self.query_triples(request)
triples = await self.query_triples(workspace, request)
logger.debug("Sending triples query response...")
r = TriplesQueryResponse(triples=triples, error=None)
await flow("response").send(r, properties={"id": id})
@ -92,13 +96,13 @@ class TriplesQueryService(FlowProcessor):
await flow("response").send(r, properties={"id": id})
async def query_triples_stream(self, request):
async def query_triples_stream(self, workspace, request):
"""
Streaming query - yields (batch, is_final) tuples.
Default implementation batches results from query_triples.
Override for true streaming from backend.
"""
triples = await self.query_triples(request)
triples = await self.query_triples(workspace, request)
batch_size = request.batch_size if request.batch_size > 0 else 20
for i in range(0, len(triples), batch_size):

View file

@ -45,7 +45,10 @@ class TriplesStoreService(FlowProcessor):
request = msg.value()
await self.store_triples(request)
# Workspace is derived from the flow the message arrived on,
# not from fields in the message payload. Topic routing is
# the isolation boundary.
await self.store_triples(flow.workspace, request)
except TooManyRequests as e:
raise e