mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-24 14:18:05 +02:00
feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840)
Introduces `workspace` as the isolation boundary for config, flows,
library, and knowledge data. Removes `user` as a schema-level field
throughout the code, API specs, and tests; workspace provides the
same separation more cleanly at the trusted flow.workspace layer
rather than through client-supplied message fields.
Design
------
- IAM tech spec (docs/tech-specs/iam.md) documents current state,
proposed auth/access model, and migration direction.
- Data ownership model (docs/tech-specs/data-ownership-model.md)
captures the workspace/collection/flow hierarchy.
Schema + messaging
------------------
- Drop `user` field from AgentRequest/Step, GraphRagQuery,
DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest,
Sparql/Rows/Structured QueryRequest, ToolServiceRequest.
- Keep collection/workspace routing via flow.workspace at the
service layer.
- Translators updated to not serialise/deserialise user.
API specs
---------
- OpenAPI schemas and path examples cleaned of user fields.
- Websocket async-api messages updated.
- Removed the unused parameters/User.yaml.
Services + base
---------------
- Librarian, collection manager, knowledge, config: all operations
scoped by workspace. Config client API takes workspace as first
positional arg.
- `flow.workspace` set at flow start time by the infrastructure;
no longer pass-through from clients.
- Tool service drops user-personalisation passthrough.
CLI + SDK
---------
- tg-init-workspace and workspace-aware import/export.
- All tg-* commands drop user args; accept --workspace.
- Python API/SDK (flow, socket_client, async_*, explainability,
library) drop user kwargs from every method signature.
MCP server
----------
- All tool endpoints drop user parameters; socket_manager no longer
keyed per user.
Flow service
------------
- Closure-based topic cleanup on flow stop: only delete topics
whose blueprint template was parameterised AND no remaining
live flow (across all workspaces) still resolves to that topic.
Three scopes fall out naturally from template analysis:
* {id} -> per-flow, deleted on stop
* {blueprint} -> per-blueprint, kept while any flow of the
same blueprint exists
* {workspace} -> per-workspace, kept while any flow in the
workspace exists
* literal -> global, never deleted (e.g. tg.request.librarian)
Fixes a bug where stopping a flow silently destroyed the global
librarian exchange, wedging all library operations until manual
restart.
RabbitMQ backend
----------------
- heartbeat=60, blocked_connection_timeout=300. Catches silently
dead connections (broker restart, orphaned channels, network
partitions) within ~2 heartbeat windows, so the consumer
reconnects and re-binds its queue rather than sitting forever
on a zombie connection.
Tests
-----
- Full test refresh: unit, integration, contract, provenance.
- Dropped user-field assertions and constructor kwargs across
~100 test files.
- Renamed user-collection isolation tests to workspace-collection.
This commit is contained in:
parent
9332089b3d
commit
d35473f7f7
377 changed files with 6868 additions and 5785 deletions
|
|
@ -125,21 +125,39 @@ class AsyncProcessor:
|
|||
response_metrics = config_resp_metrics,
|
||||
)
|
||||
|
||||
async def fetch_config(self):
|
||||
"""Fetch full config from config service using a short-lived
|
||||
request/response client. Returns (config, version) or raises."""
|
||||
client = self._create_config_client()
|
||||
try:
|
||||
await client.start()
|
||||
resp = await client.request(
|
||||
ConfigRequest(operation="config"),
|
||||
timeout=10,
|
||||
)
|
||||
if resp.error:
|
||||
raise RuntimeError(f"Config error: {resp.error.message}")
|
||||
return resp.config, resp.version
|
||||
finally:
|
||||
await client.stop()
|
||||
async def _fetch_type_workspace(self, client, workspace, config_type):
|
||||
"""Fetch config values of a single type within one workspace.
|
||||
Returns dict of {key: value}."""
|
||||
resp = await client.request(
|
||||
ConfigRequest(
|
||||
operation="getvalues",
|
||||
workspace=workspace,
|
||||
type=config_type,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
if resp.error:
|
||||
raise RuntimeError(f"Config error: {resp.error.message}")
|
||||
return {v.key: v.value for v in resp.values}
|
||||
|
||||
async def _fetch_type_all_workspaces(self, client, config_type):
|
||||
"""Fetch config values of a single type across all workspaces.
|
||||
Returns dict of {workspace: {key: value}}."""
|
||||
resp = await client.request(
|
||||
ConfigRequest(
|
||||
operation="getvalues-all-ws",
|
||||
type=config_type,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
if resp.error:
|
||||
raise RuntimeError(f"Config error: {resp.error.message}")
|
||||
|
||||
grouped = {}
|
||||
for v in resp.values:
|
||||
ws = grouped.setdefault(v.workspace, {})
|
||||
ws[v.key] = v.value
|
||||
return grouped, resp.version
|
||||
|
||||
# This is called to start dynamic behaviour.
|
||||
# Implements the subscribe-then-fetch pattern to avoid race conditions.
|
||||
|
|
@ -155,21 +173,51 @@ class AsyncProcessor:
|
|||
# processed by on_config_notify, which does the version check
|
||||
|
||||
async def fetch_and_apply_config(self):
|
||||
"""Fetch full config from config service and apply to all handlers.
|
||||
Retries until successful — config service may not be ready yet."""
|
||||
"""Startup: for each registered handler, fetch config for all its
|
||||
types across all workspaces and invoke the handler once per
|
||||
workspace. Retries until successful — config service may not be
|
||||
ready yet."""
|
||||
|
||||
while self.running:
|
||||
|
||||
try:
|
||||
config, version = await self.fetch_config()
|
||||
client = self._create_config_client()
|
||||
try:
|
||||
await client.start()
|
||||
|
||||
logger.info(f"Fetched config version {version}")
|
||||
version = 0
|
||||
|
||||
self.config_version = version
|
||||
for entry in self.config_handlers:
|
||||
handler_types = entry["types"]
|
||||
|
||||
# Apply to all handlers (startup = invoke all)
|
||||
for entry in self.config_handlers:
|
||||
await entry["handler"](config, version)
|
||||
# Handlers registered without types get nothing
|
||||
# at startup (there is no "all types" fetch).
|
||||
if not handler_types:
|
||||
continue
|
||||
|
||||
# Group all registered types by workspace:
|
||||
# {workspace: {type: {key: value}}}
|
||||
per_ws = {}
|
||||
for t in handler_types:
|
||||
type_data, v = \
|
||||
await self._fetch_type_all_workspaces(
|
||||
client, t,
|
||||
)
|
||||
version = max(version, v)
|
||||
for ws, kv in type_data.items():
|
||||
per_ws.setdefault(ws, {})[t] = kv
|
||||
|
||||
# Call the handler once per workspace
|
||||
for ws, config in per_ws.items():
|
||||
await entry["handler"](ws, config, version)
|
||||
|
||||
logger.info(
|
||||
f"Applied startup config version {version}"
|
||||
)
|
||||
self.config_version = version
|
||||
|
||||
finally:
|
||||
await client.stop()
|
||||
|
||||
return
|
||||
|
||||
|
|
@ -204,8 +252,9 @@ class AsyncProcessor:
|
|||
# Called when a config notify message arrives
|
||||
async def on_config_notify(self, message, consumer, flow):
|
||||
|
||||
notify_version = message.value().version
|
||||
notify_types = set(message.value().types)
|
||||
v = message.value()
|
||||
notify_version = v.version
|
||||
changes = v.changes # dict of type -> [workspaces]
|
||||
|
||||
# Skip if we already have this version or newer
|
||||
if notify_version <= self.config_version:
|
||||
|
|
@ -215,41 +264,60 @@ class AsyncProcessor:
|
|||
)
|
||||
return
|
||||
|
||||
# Check if any handler cares about the affected types
|
||||
if notify_types:
|
||||
any_interested = False
|
||||
for entry in self.config_handlers:
|
||||
handler_types = entry["types"]
|
||||
if handler_types is None or notify_types & handler_types:
|
||||
any_interested = True
|
||||
break
|
||||
notify_types = set(changes.keys())
|
||||
|
||||
if not any_interested:
|
||||
logger.debug(
|
||||
f"Ignoring config notify v{notify_version}, "
|
||||
f"no handlers for types {notify_types}"
|
||||
)
|
||||
self.config_version = notify_version
|
||||
return
|
||||
# Filter out handlers that don't care about any of the changed
|
||||
# types. A handler registered without types never fires on
|
||||
# notifications (nothing to scope to).
|
||||
interested = []
|
||||
for entry in self.config_handlers:
|
||||
handler_types = entry["types"]
|
||||
if handler_types and notify_types & handler_types:
|
||||
interested.append(entry)
|
||||
|
||||
if not interested:
|
||||
logger.debug(
|
||||
f"Ignoring config notify v{notify_version}, "
|
||||
f"no handlers for types {notify_types}"
|
||||
)
|
||||
self.config_version = notify_version
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Config notify v{notify_version} types={list(notify_types)}, "
|
||||
f"fetching config..."
|
||||
f"Config notify v{notify_version} "
|
||||
f"types={list(notify_types)}, fetching config..."
|
||||
)
|
||||
|
||||
# Fetch full config using short-lived client
|
||||
try:
|
||||
config, version = await self.fetch_config()
|
||||
client = self._create_config_client()
|
||||
try:
|
||||
await client.start()
|
||||
|
||||
self.config_version = version
|
||||
for entry in interested:
|
||||
handler_types = entry["types"]
|
||||
|
||||
# Invoke handlers that care about the affected types
|
||||
for entry in self.config_handlers:
|
||||
handler_types = entry["types"]
|
||||
if handler_types is None:
|
||||
await entry["handler"](config, version)
|
||||
elif not notify_types or notify_types & handler_types:
|
||||
await entry["handler"](config, version)
|
||||
# Build {workspace: {type: {key: value}}} for types
|
||||
# this handler cares about, where the workspace was
|
||||
# affected for that type.
|
||||
per_ws = {}
|
||||
for t in handler_types:
|
||||
if t not in changes:
|
||||
continue
|
||||
for ws in changes[t]:
|
||||
kv = await self._fetch_type_workspace(
|
||||
client, ws, t,
|
||||
)
|
||||
per_ws.setdefault(ws, {})[t] = kv
|
||||
|
||||
for ws, config in per_ws.items():
|
||||
await entry["handler"](
|
||||
ws, config, notify_version,
|
||||
)
|
||||
|
||||
finally:
|
||||
await client.stop()
|
||||
|
||||
self.config_version = notify_version
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
|
|||
|
|
@ -48,12 +48,13 @@ class ChunkingService(FlowProcessor):
|
|||
await super(ChunkingService, self).start()
|
||||
await self.librarian.start()
|
||||
|
||||
async def get_document_text(self, doc):
|
||||
async def get_document_text(self, doc, workspace):
|
||||
"""
|
||||
Get text content from a TextDocument, fetching from librarian if needed.
|
||||
|
||||
Args:
|
||||
doc: TextDocument with either inline text or document_id
|
||||
workspace: Workspace for librarian lookup (from flow.workspace)
|
||||
|
||||
Returns:
|
||||
str: The document text content
|
||||
|
|
@ -62,7 +63,7 @@ class ChunkingService(FlowProcessor):
|
|||
logger.info(f"Fetching document {doc.document_id} from librarian...")
|
||||
text = await self.librarian.fetch_document_text(
|
||||
document_id=doc.document_id,
|
||||
user=doc.metadata.user,
|
||||
workspace=workspace,
|
||||
)
|
||||
logger.info(f"Fetched {len(text)} characters from librarian")
|
||||
return text
|
||||
|
|
|
|||
|
|
@ -15,114 +15,139 @@ class CollectionConfigHandler:
|
|||
Storage services should:
|
||||
1. Inherit from this class along with their service base class
|
||||
2. Call register_config_handler(self.on_collection_config) in __init__
|
||||
3. Implement create_collection(user, collection, metadata) method
|
||||
4. Implement delete_collection(user, collection) method
|
||||
3. Implement create_collection(workspace, collection, metadata) method
|
||||
4. Implement delete_collection(workspace, collection) method
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Track known collections: {(user, collection): metadata_dict}
|
||||
# Track known collections: {(workspace, collection): metadata_dict}
|
||||
self.known_collections: Dict[tuple, dict] = {}
|
||||
# Pass remaining kwargs up the inheritance chain
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def on_collection_config(self, config: dict, version: int):
|
||||
async def on_collection_config(
|
||||
self, workspace: str, config: dict, version: int
|
||||
):
|
||||
"""
|
||||
Handle config push messages and extract collection information
|
||||
for a single workspace.
|
||||
|
||||
Args:
|
||||
workspace: Workspace the config applies to
|
||||
config: Configuration dictionary from ConfigPush message
|
||||
version: Configuration version number
|
||||
"""
|
||||
logger.info(f"Processing collection configuration (version {version})")
|
||||
logger.info(
|
||||
f"Processing collection configuration "
|
||||
f"(version {version}, workspace {workspace})"
|
||||
)
|
||||
|
||||
# Extract collections from config (treat missing key as empty)
|
||||
# Extract collections from config (treat missing key as empty).
|
||||
# Each config key IS the collection name — config is already
|
||||
# partitioned by workspace, so no workspace prefix is needed
|
||||
# on the key.
|
||||
collection_config = config.get("collection", {})
|
||||
|
||||
# Track which collections we've seen in this config
|
||||
current_collections: Set[tuple] = set()
|
||||
|
||||
# Process each collection in the config
|
||||
for key, value_json in collection_config.items():
|
||||
for collection, value_json in collection_config.items():
|
||||
try:
|
||||
# Parse user:collection key
|
||||
if ":" not in key:
|
||||
logger.warning(f"Invalid collection key format (expected user:collection): {key}")
|
||||
continue
|
||||
current_collections.add((workspace, collection))
|
||||
|
||||
user, collection = key.split(":", 1)
|
||||
current_collections.add((user, collection))
|
||||
|
||||
# Parse metadata
|
||||
metadata = json.loads(value_json)
|
||||
|
||||
# Check if this is a new collection or updated
|
||||
collection_key = (user, collection)
|
||||
if collection_key not in self.known_collections:
|
||||
logger.info(f"New collection detected: {user}/{collection}")
|
||||
await self.create_collection(user, collection, metadata)
|
||||
self.known_collections[collection_key] = metadata
|
||||
key = (workspace, collection)
|
||||
if key not in self.known_collections:
|
||||
logger.info(
|
||||
f"New collection detected: {workspace}/{collection}"
|
||||
)
|
||||
await self.create_collection(
|
||||
workspace, collection, metadata
|
||||
)
|
||||
self.known_collections[key] = metadata
|
||||
else:
|
||||
# Collection already exists, update metadata if changed
|
||||
if self.known_collections[collection_key] != metadata:
|
||||
logger.info(f"Collection metadata updated: {user}/{collection}")
|
||||
# Most storage services don't need to do anything for metadata updates
|
||||
# They just need to know the collection exists
|
||||
self.known_collections[collection_key] = metadata
|
||||
if self.known_collections[key] != metadata:
|
||||
logger.info(
|
||||
f"Collection metadata updated: "
|
||||
f"{workspace}/{collection}"
|
||||
)
|
||||
self.known_collections[key] = metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing collection config for key {key}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error processing collection config for "
|
||||
f"{workspace}/{collection}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Find collections that were deleted (in known but not in current)
|
||||
deleted_collections = set(self.known_collections.keys()) - current_collections
|
||||
for user, collection in deleted_collections:
|
||||
logger.info(f"Collection deleted: {user}/{collection}")
|
||||
# Find collections for THIS workspace that were deleted (in
|
||||
# known but not in current). Only compare collections owned by
|
||||
# this workspace — other workspaces' collections are not
|
||||
# affected by this config update.
|
||||
known_for_ws = {
|
||||
(w, c) for (w, c) in self.known_collections.keys()
|
||||
if w == workspace
|
||||
}
|
||||
deleted_collections = known_for_ws - current_collections
|
||||
for ws, collection in deleted_collections:
|
||||
logger.info(f"Collection deleted: {ws}/{collection}")
|
||||
try:
|
||||
# Remove from known_collections FIRST to immediately reject new writes
|
||||
# This eliminates race condition with worker threads
|
||||
del self.known_collections[(user, collection)]
|
||||
# Physical deletion happens after - worker threads already rejecting writes
|
||||
await self.delete_collection(user, collection)
|
||||
# Remove from known_collections FIRST to immediately
|
||||
# reject new writes
|
||||
del self.known_collections[(ws, collection)]
|
||||
await self.delete_collection(ws, collection)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting collection {user}/{collection}: {e}", exc_info=True)
|
||||
# If physical deletion failed, should we re-add to known_collections?
|
||||
# For now, keep it removed - collection is logically deleted per config
|
||||
logger.error(
|
||||
f"Error deleting collection {ws}/{collection}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.debug(f"Collection config processing complete. Known collections: {len(self.known_collections)}")
|
||||
logger.debug(
|
||||
f"Collection config processing complete. "
|
||||
f"Known collections: {len(self.known_collections)}"
|
||||
)
|
||||
|
||||
async def create_collection(self, user: str, collection: str, metadata: dict):
|
||||
async def create_collection(
|
||||
self, workspace: str, collection: str, metadata: dict,
|
||||
):
|
||||
"""
|
||||
Create a collection in the storage backend.
|
||||
|
||||
Subclasses must implement this method.
|
||||
|
||||
Args:
|
||||
user: User ID
|
||||
workspace: Workspace ID
|
||||
collection: Collection ID
|
||||
metadata: Collection metadata dictionary
|
||||
"""
|
||||
raise NotImplementedError("Storage service must implement create_collection method")
|
||||
raise NotImplementedError(
|
||||
"Storage service must implement create_collection method"
|
||||
)
|
||||
|
||||
async def delete_collection(self, user: str, collection: str):
|
||||
async def delete_collection(self, workspace: str, collection: str):
|
||||
"""
|
||||
Delete a collection from the storage backend.
|
||||
|
||||
Subclasses must implement this method.
|
||||
|
||||
Args:
|
||||
user: User ID
|
||||
workspace: Workspace ID
|
||||
collection: Collection ID
|
||||
"""
|
||||
raise NotImplementedError("Storage service must implement delete_collection method")
|
||||
raise NotImplementedError(
|
||||
"Storage service must implement delete_collection method"
|
||||
)
|
||||
|
||||
def collection_exists(self, user: str, collection: str) -> bool:
|
||||
def collection_exists(self, workspace: str, collection: str) -> bool:
|
||||
"""
|
||||
Check if a collection is known to exist
|
||||
Check if a collection is known to exist.
|
||||
|
||||
Args:
|
||||
user: User ID
|
||||
workspace: Workspace ID
|
||||
collection: Collection ID
|
||||
|
||||
Returns:
|
||||
True if collection exists, False otherwise
|
||||
"""
|
||||
return (user, collection) in self.known_collections
|
||||
return (workspace, collection) in self.known_collections
|
||||
|
|
|
|||
|
|
@ -18,10 +18,11 @@ class ConfigClient(RequestResponse):
|
|||
)
|
||||
return resp
|
||||
|
||||
async def get(self, type, key, timeout=CONFIG_TIMEOUT):
|
||||
async def get(self, workspace, type, key, timeout=CONFIG_TIMEOUT):
|
||||
"""Get a single config value. Returns the value string or None."""
|
||||
resp = await self._request(
|
||||
operation="get",
|
||||
workspace=workspace,
|
||||
keys=[ConfigKey(type=type, key=key)],
|
||||
timeout=timeout,
|
||||
)
|
||||
|
|
@ -29,19 +30,21 @@ class ConfigClient(RequestResponse):
|
|||
return resp.values[0].value
|
||||
return None
|
||||
|
||||
async def put(self, type, key, value, timeout=CONFIG_TIMEOUT):
|
||||
async def put(self, workspace, type, key, value, timeout=CONFIG_TIMEOUT):
|
||||
"""Put a single config value."""
|
||||
await self._request(
|
||||
operation="put",
|
||||
workspace=workspace,
|
||||
values=[ConfigValue(type=type, key=key, value=value)],
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def put_many(self, values, timeout=CONFIG_TIMEOUT):
|
||||
"""Put multiple config values in a single request.
|
||||
values is a list of (type, key, value) tuples."""
|
||||
async def put_many(self, workspace, values, timeout=CONFIG_TIMEOUT):
|
||||
"""Put multiple config values in a single request within a
|
||||
single workspace. values is a list of (type, key, value) tuples."""
|
||||
await self._request(
|
||||
operation="put",
|
||||
workspace=workspace,
|
||||
values=[
|
||||
ConfigValue(type=t, key=k, value=v)
|
||||
for t, k, v in values
|
||||
|
|
@ -49,19 +52,21 @@ class ConfigClient(RequestResponse):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def delete(self, type, key, timeout=CONFIG_TIMEOUT):
|
||||
async def delete(self, workspace, type, key, timeout=CONFIG_TIMEOUT):
|
||||
"""Delete a single config key."""
|
||||
await self._request(
|
||||
operation="delete",
|
||||
workspace=workspace,
|
||||
keys=[ConfigKey(type=type, key=key)],
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def delete_many(self, keys, timeout=CONFIG_TIMEOUT):
|
||||
"""Delete multiple config keys in a single request.
|
||||
keys is a list of (type, key) tuples."""
|
||||
async def delete_many(self, workspace, keys, timeout=CONFIG_TIMEOUT):
|
||||
"""Delete multiple config keys in a single request within a
|
||||
single workspace. keys is a list of (type, key) tuples."""
|
||||
await self._request(
|
||||
operation="delete",
|
||||
workspace=workspace,
|
||||
keys=[
|
||||
ConfigKey(type=t, key=k)
|
||||
for t, k in keys
|
||||
|
|
@ -69,15 +74,26 @@ class ConfigClient(RequestResponse):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def keys(self, type, timeout=CONFIG_TIMEOUT):
|
||||
"""List all keys for a config type."""
|
||||
async def keys(self, workspace, type, timeout=CONFIG_TIMEOUT):
|
||||
"""List all keys for a config type within a workspace."""
|
||||
resp = await self._request(
|
||||
operation="list",
|
||||
workspace=workspace,
|
||||
type=type,
|
||||
timeout=timeout,
|
||||
)
|
||||
return resp.directory
|
||||
|
||||
async def workspaces_for_type(self, type, timeout=CONFIG_TIMEOUT):
|
||||
"""Return the set of distinct workspaces with any config of
|
||||
the given type."""
|
||||
resp = await self._request(
|
||||
operation="getvalues-all-ws",
|
||||
type=type,
|
||||
timeout=timeout,
|
||||
)
|
||||
return {v.workspace for v in resp.values if v.workspace}
|
||||
|
||||
|
||||
class ConfigClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -24,7 +24,10 @@ class ConsumerSpec(Spec):
|
|||
flow = flow,
|
||||
backend = processor.pubsub,
|
||||
topic = definition["topics"][self.name],
|
||||
subscriber = processor.id + "--" + flow.name + "--" + self.name,
|
||||
subscriber = (
|
||||
processor.id + "--" + flow.workspace + "--" +
|
||||
flow.name + "--" + self.name
|
||||
),
|
||||
schema = self.schema,
|
||||
handler = self.handler,
|
||||
metrics = consumer_metrics,
|
||||
|
|
|
|||
|
|
@ -9,14 +9,12 @@ from .. knowledge import Uri, Literal
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DocumentEmbeddingsClient(RequestResponse):
|
||||
async def query(self, vector, limit=20, user="trustgraph",
|
||||
collection="default", timeout=30):
|
||||
async def query(self, vector, limit=20, collection="default", timeout=30):
|
||||
|
||||
resp = await self.request(
|
||||
DocumentEmbeddingsRequest(
|
||||
vector = vector,
|
||||
limit = limit,
|
||||
user = user,
|
||||
collection = collection
|
||||
),
|
||||
timeout=timeout
|
||||
|
|
|
|||
|
|
@ -60,7 +60,9 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
|||
|
||||
logger.debug(f"Handling document embeddings query request {id}...")
|
||||
|
||||
docs = await self.query_document_embeddings(request)
|
||||
docs = await self.query_document_embeddings(
|
||||
flow.workspace, request,
|
||||
)
|
||||
|
||||
logger.debug("Sending document embeddings query response...")
|
||||
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ class DocumentEmbeddingsStoreService(FlowProcessor):
|
|||
|
||||
request = msg.value()
|
||||
|
||||
await self.store_document_embeddings(request)
|
||||
# Workspace comes from the flow the message arrived on.
|
||||
await self.store_document_embeddings(flow.workspace, request)
|
||||
|
||||
except TooManyRequests as e:
|
||||
raise e
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
Base class for dynamically pluggable tool services.
|
||||
|
||||
Tool services are Pulsar services that can be invoked as agent tools.
|
||||
They receive a ToolServiceRequest with user, config, and arguments,
|
||||
They receive a ToolServiceRequest with config and arguments,
|
||||
and return a ToolServiceResponse with the result.
|
||||
|
||||
Uses direct Pulsar topics (no flow configuration required):
|
||||
|
|
@ -42,7 +42,6 @@ class DynamicToolService(AsyncProcessor):
|
|||
the tool's logic.
|
||||
|
||||
The invoke method receives:
|
||||
- user: The user context for multi-tenancy
|
||||
- config: Dict of config values from the tool descriptor
|
||||
- arguments: Dict of arguments from the LLM
|
||||
|
||||
|
|
@ -115,14 +114,13 @@ class DynamicToolService(AsyncProcessor):
|
|||
id = msg.properties().get("id", "unknown")
|
||||
|
||||
# Parse the request
|
||||
user = request.user or "trustgraph"
|
||||
config = json.loads(request.config) if request.config else {}
|
||||
arguments = json.loads(request.arguments) if request.arguments else {}
|
||||
|
||||
logger.debug(f"Tool service request: user={user}, config={config}, arguments={arguments}")
|
||||
logger.debug(f"Tool service request: config={config}, arguments={arguments}")
|
||||
|
||||
# Invoke the tool implementation
|
||||
response = await self.invoke(user, config, arguments)
|
||||
response = await self.invoke(config, arguments)
|
||||
|
||||
# Send success response
|
||||
await self.producer.send(
|
||||
|
|
@ -159,14 +157,13 @@ class DynamicToolService(AsyncProcessor):
|
|||
properties={"id": id if id else "unknown"}
|
||||
)
|
||||
|
||||
async def invoke(self, user, config, arguments):
|
||||
async def invoke(self, config, arguments):
|
||||
"""
|
||||
Invoke the tool service.
|
||||
|
||||
Override this method in subclasses to implement the tool's logic.
|
||||
|
||||
Args:
|
||||
user: The user context for multi-tenancy
|
||||
config: Dict of config values from the tool descriptor
|
||||
arguments: Dict of arguments from the LLM
|
||||
|
||||
|
|
|
|||
|
|
@ -4,15 +4,16 @@ import asyncio
|
|||
class Flow:
|
||||
"""
|
||||
Runtime representation of a deployed flow process.
|
||||
|
||||
|
||||
This class maintains internal processor states and orchestrates
|
||||
lifecycles (start, stop) for inputs (consumers) and parameters
|
||||
lifecycles (start, stop) for inputs (consumers) and parameters
|
||||
that drive data flowing across linked nodes.
|
||||
"""
|
||||
def __init__(self, id, flow, processor, defn):
|
||||
def __init__(self, id, flow, workspace, processor, defn):
|
||||
|
||||
self.id = id
|
||||
self.name = flow
|
||||
self.workspace = workspace
|
||||
|
||||
self.producer = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -35,6 +35,8 @@ class FlowProcessor(AsyncProcessor):
|
|||
)
|
||||
|
||||
# Initialise flow information state
|
||||
# Keyed by (workspace, flow) tuples; each workspace has its own
|
||||
# set of flow variants for this processor.
|
||||
self.flows = {}
|
||||
|
||||
# These can be overriden by a derived class:
|
||||
|
|
@ -48,23 +50,28 @@ class FlowProcessor(AsyncProcessor):
|
|||
def register_specification(self, spec: Any) -> None:
|
||||
self.specifications.append(spec)
|
||||
|
||||
# Start processing for a new flow
|
||||
async def start_flow(self, flow, defn):
|
||||
self.flows[flow] = Flow(self.id, flow, self, defn)
|
||||
await self.flows[flow].start()
|
||||
logger.info(f"Started flow: {flow}")
|
||||
|
||||
# Stop processing for a new flow
|
||||
async def stop_flow(self, flow):
|
||||
if flow in self.flows:
|
||||
await self.flows[flow].stop()
|
||||
del self.flows[flow]
|
||||
logger.info(f"Stopped flow: {flow}")
|
||||
# Start processing for a new flow within a workspace
|
||||
async def start_flow(self, workspace, flow, defn):
|
||||
key = (workspace, flow)
|
||||
self.flows[key] = Flow(self.id, flow, workspace, self, defn)
|
||||
await self.flows[key].start()
|
||||
logger.info(f"Started flow: {workspace}/{flow}")
|
||||
|
||||
# Event handler - called for a configuration change
|
||||
async def on_configure_flows(self, config, version):
|
||||
# Stop processing for a flow within a workspace
|
||||
async def stop_flow(self, workspace, flow):
|
||||
key = (workspace, flow)
|
||||
if key in self.flows:
|
||||
await self.flows[key].stop()
|
||||
del self.flows[key]
|
||||
logger.info(f"Stopped flow: {workspace}/{flow}")
|
||||
|
||||
logger.info(f"Got config version {version}")
|
||||
# Event handler - called for a configuration change for a single
|
||||
# workspace
|
||||
async def on_configure_flows(self, workspace, config, version):
|
||||
|
||||
logger.info(
|
||||
f"Got config version {version} for workspace {workspace}"
|
||||
)
|
||||
|
||||
config_type = f"processor:{self.id}"
|
||||
|
||||
|
|
@ -76,26 +83,28 @@ class FlowProcessor(AsyncProcessor):
|
|||
for k, v in config[config_type].items()
|
||||
}
|
||||
else:
|
||||
logger.debug("No configuration settings for me.")
|
||||
logger.debug(
|
||||
f"No configuration settings for me in {workspace}."
|
||||
)
|
||||
flow_config = {}
|
||||
|
||||
# Get list of flows which should be running and are currently
|
||||
# running
|
||||
wanted_flows = flow_config.keys()
|
||||
# This takes a copy, needed because dict gets modified by stop_flow
|
||||
current_flows = list(self.flows.keys())
|
||||
# Get list of flows which should be running in this workspace,
|
||||
# and the list currently running in this workspace
|
||||
wanted_flows = set(flow_config.keys())
|
||||
current_flows = {
|
||||
f for (ws, f) in self.flows.keys() if ws == workspace
|
||||
}
|
||||
|
||||
# Start all the flows which arent currently running
|
||||
for flow in wanted_flows:
|
||||
if flow not in current_flows:
|
||||
await self.start_flow(flow, flow_config[flow])
|
||||
# Start all the flows which aren't currently running in this
|
||||
# workspace
|
||||
for flow in wanted_flows - current_flows:
|
||||
await self.start_flow(workspace, flow, flow_config[flow])
|
||||
|
||||
# Stop all the unwanted flows which are due to be stopped
|
||||
for flow in current_flows:
|
||||
if flow not in wanted_flows:
|
||||
await self.stop_flow(flow)
|
||||
# Stop all the unwanted flows in this workspace
|
||||
for flow in current_flows - wanted_flows:
|
||||
await self.stop_flow(workspace, flow)
|
||||
|
||||
logger.info("Handled config update")
|
||||
logger.info(f"Handled config update for workspace {workspace}")
|
||||
|
||||
# Start threads, just call parent
|
||||
async def start(self):
|
||||
|
|
|
|||
|
|
@ -22,14 +22,12 @@ def to_value(x: Any) -> Any:
|
|||
return Literal(x.value or x.iri)
|
||||
|
||||
class GraphEmbeddingsClient(RequestResponse):
|
||||
async def query(self, vector, limit=20, user="trustgraph",
|
||||
collection="default", timeout=30):
|
||||
async def query(self, vector, limit=20, collection="default", timeout=30):
|
||||
|
||||
resp = await self.request(
|
||||
GraphEmbeddingsRequest(
|
||||
vector = vector,
|
||||
limit = limit,
|
||||
user = user,
|
||||
collection = collection
|
||||
),
|
||||
timeout=timeout
|
||||
|
|
|
|||
|
|
@ -60,7 +60,9 @@ class GraphEmbeddingsQueryService(FlowProcessor):
|
|||
|
||||
logger.debug(f"Handling graph embeddings query request {id}...")
|
||||
|
||||
entities = await self.query_graph_embeddings(request)
|
||||
entities = await self.query_graph_embeddings(
|
||||
flow.workspace, request,
|
||||
)
|
||||
|
||||
logger.debug("Sending graph embeddings query response...")
|
||||
r = GraphEmbeddingsResponse(entities=entities, error=None)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ class GraphEmbeddingsStoreService(FlowProcessor):
|
|||
|
||||
request = msg.value()
|
||||
|
||||
await self.store_graph_embeddings(request)
|
||||
# Workspace comes from the flow the message arrived on.
|
||||
await self.store_graph_embeddings(flow.workspace, request)
|
||||
|
||||
except TooManyRequests as e:
|
||||
raise e
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
|
|||
from .. schema import GraphRagQuery, GraphRagResponse
|
||||
|
||||
class GraphRagClient(RequestResponse):
|
||||
async def rag(self, query, user="trustgraph", collection="default",
|
||||
async def rag(self, query, collection="default",
|
||||
chunk_callback=None, explain_callback=None,
|
||||
parent_uri="",
|
||||
timeout=600):
|
||||
|
|
@ -12,7 +12,6 @@ class GraphRagClient(RequestResponse):
|
|||
|
||||
Args:
|
||||
query: The question to ask
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
chunk_callback: Optional async callback(text, end_of_stream) for text chunks
|
||||
explain_callback: Optional async callback(explain_id, explain_graph, explain_triples) for explain notifications
|
||||
|
|
@ -49,7 +48,6 @@ class GraphRagClient(RequestResponse):
|
|||
await self.request(
|
||||
GraphRagQuery(
|
||||
query = query,
|
||||
user = user,
|
||||
collection = collection,
|
||||
parent_uri = parent_uri,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ Usage:
|
|||
id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params
|
||||
)
|
||||
await self.librarian.start()
|
||||
content = await self.librarian.fetch_document_content(doc_id, user)
|
||||
content = await self.librarian.fetch_document_content(doc_id, workspace)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
|
@ -150,7 +150,7 @@ class LibrarianClient:
|
|||
finally:
|
||||
self._streams.pop(request_id, None)
|
||||
|
||||
async def fetch_document_content(self, document_id, user, timeout=120):
|
||||
async def fetch_document_content(self, document_id, workspace, timeout=120):
|
||||
"""Fetch document content using streaming.
|
||||
|
||||
Returns base64-encoded content. Caller is responsible for decoding.
|
||||
|
|
@ -158,7 +158,7 @@ class LibrarianClient:
|
|||
req = LibrarianRequest(
|
||||
operation="stream-document",
|
||||
document_id=document_id,
|
||||
user=user,
|
||||
workspace=workspace,
|
||||
)
|
||||
chunks = await self.stream(req, timeout=timeout)
|
||||
|
||||
|
|
@ -176,24 +176,24 @@ class LibrarianClient:
|
|||
|
||||
return base64.b64encode(raw)
|
||||
|
||||
async def fetch_document_text(self, document_id, user, timeout=120):
|
||||
async def fetch_document_text(self, document_id, workspace, timeout=120):
|
||||
"""Fetch document content and decode as UTF-8 text."""
|
||||
content = await self.fetch_document_content(
|
||||
document_id, user, timeout=timeout,
|
||||
document_id, workspace, timeout=timeout,
|
||||
)
|
||||
return base64.b64decode(content).decode("utf-8")
|
||||
|
||||
async def fetch_document_metadata(self, document_id, user, timeout=120):
|
||||
async def fetch_document_metadata(self, document_id, workspace, timeout=120):
|
||||
"""Fetch document metadata from the librarian."""
|
||||
req = LibrarianRequest(
|
||||
operation="get-document-metadata",
|
||||
document_id=document_id,
|
||||
user=user,
|
||||
workspace=workspace,
|
||||
)
|
||||
response = await self.request(req, timeout=timeout)
|
||||
return response.document_metadata
|
||||
|
||||
async def save_child_document(self, doc_id, parent_id, user, content,
|
||||
async def save_child_document(self, doc_id, parent_id, workspace, content,
|
||||
document_type="chunk", title=None,
|
||||
kind="text/plain", timeout=120):
|
||||
"""Save a child document to the librarian."""
|
||||
|
|
@ -202,7 +202,7 @@ class LibrarianClient:
|
|||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
workspace=workspace,
|
||||
kind=kind,
|
||||
title=title or doc_id,
|
||||
parent_id=parent_id,
|
||||
|
|
@ -218,7 +218,7 @@ class LibrarianClient:
|
|||
await self.request(req, timeout=timeout)
|
||||
return doc_id
|
||||
|
||||
async def save_document(self, doc_id, user, content, title=None,
|
||||
async def save_document(self, doc_id, workspace, content, title=None,
|
||||
document_type="answer", kind="text/plain",
|
||||
timeout=120):
|
||||
"""Save a document to the librarian."""
|
||||
|
|
@ -227,7 +227,7 @@ class LibrarianClient:
|
|||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
workspace=workspace,
|
||||
kind=kind,
|
||||
title=title or doc_id,
|
||||
document_type=document_type,
|
||||
|
|
@ -238,7 +238,7 @@ class LibrarianClient:
|
|||
document_id=doc_id,
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content).decode("utf-8"),
|
||||
user=user,
|
||||
workspace=workspace,
|
||||
)
|
||||
|
||||
await self.request(req, timeout=timeout)
|
||||
|
|
|
|||
|
|
@ -133,8 +133,9 @@ class RequestResponseSpec(Spec):
|
|||
# Make subscription names unique, so that all subscribers get
|
||||
# to see all response messages
|
||||
subscription = (
|
||||
processor.id + "--" + flow.name + "--" + self.request_name +
|
||||
"--" + str(uuid.uuid4())
|
||||
processor.id + "--" + flow.workspace + "--" +
|
||||
flow.name + "--" + self.request_name + "--" +
|
||||
str(uuid.uuid4())
|
||||
),
|
||||
consumer_name = flow.id,
|
||||
request_topic = definition["topics"][self.request_name],
|
||||
|
|
|
|||
|
|
@ -3,13 +3,12 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
|||
|
||||
class RowEmbeddingsQueryClient(RequestResponse):
|
||||
async def row_embeddings_query(
|
||||
self, vector, schema_name, user="trustgraph", collection="default",
|
||||
self, vector, schema_name, collection="default",
|
||||
index_name=None, limit=10, timeout=600
|
||||
):
|
||||
request = RowEmbeddingsRequest(
|
||||
vector=vector,
|
||||
schema_name=schema_name,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=limit
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,11 +2,10 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
|
|||
from .. schema import StructuredQueryRequest, StructuredQueryResponse
|
||||
|
||||
class StructuredQueryClient(RequestResponse):
|
||||
async def structured_query(self, question, user="trustgraph", collection="default", timeout=600):
|
||||
async def structured_query(self, question, collection="default", timeout=600):
|
||||
resp = await self.request(
|
||||
StructuredQueryRequest(
|
||||
question = question,
|
||||
user = user,
|
||||
collection = collection
|
||||
),
|
||||
timeout=timeout
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class SubscriberSpec(Spec):
|
|||
subscriber = Subscriber(
|
||||
backend = processor.pubsub,
|
||||
topic = definition["topics"][self.name],
|
||||
subscription = flow.id,
|
||||
subscription = flow.id + "--" + flow.workspace + "--" + flow.name,
|
||||
consumer_name = flow.id,
|
||||
schema = self.schema,
|
||||
metrics = subscriber_metrics,
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ class ToolService(FlowProcessor):
|
|||
id = msg.properties()["id"]
|
||||
|
||||
response = await self.invoke_tool(
|
||||
flow.workspace,
|
||||
request.name,
|
||||
json.loads(request.parameters) if request.parameters else {},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,12 +11,11 @@ logger = logging.getLogger(__name__)
|
|||
class ToolServiceClient(RequestResponse):
|
||||
"""Client for invoking dynamically configured tool services."""
|
||||
|
||||
async def call(self, user, config, arguments, timeout=600):
|
||||
async def call(self, config, arguments, timeout=600):
|
||||
"""
|
||||
Call a tool service.
|
||||
|
||||
Args:
|
||||
user: User context for multi-tenancy
|
||||
config: Dict of config values (e.g., {"collection": "customers"})
|
||||
arguments: Dict of arguments from LLM
|
||||
timeout: Request timeout in seconds
|
||||
|
|
@ -26,7 +25,6 @@ class ToolServiceClient(RequestResponse):
|
|||
"""
|
||||
resp = await self.request(
|
||||
ToolServiceRequest(
|
||||
user=user,
|
||||
config=json.dumps(config) if config else "{}",
|
||||
arguments=json.dumps(arguments) if arguments else "{}",
|
||||
),
|
||||
|
|
@ -38,12 +36,11 @@ class ToolServiceClient(RequestResponse):
|
|||
|
||||
return resp.response
|
||||
|
||||
async def call_streaming(self, user, config, arguments, callback, timeout=600):
|
||||
async def call_streaming(self, config, arguments, callback, timeout=600):
|
||||
"""
|
||||
Call a tool service with streaming response.
|
||||
|
||||
Args:
|
||||
user: User context for multi-tenancy
|
||||
config: Dict of config values
|
||||
arguments: Dict of arguments from LLM
|
||||
callback: Async function called with each response chunk
|
||||
|
|
@ -66,7 +63,6 @@ class ToolServiceClient(RequestResponse):
|
|||
|
||||
await self.request(
|
||||
ToolServiceRequest(
|
||||
user=user,
|
||||
config=json.dumps(config) if config else "{}",
|
||||
arguments=json.dumps(arguments) if arguments else "{}",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ def from_value(x: Any) -> Any:
|
|||
|
||||
class TriplesClient(RequestResponse):
|
||||
async def query(self, s=None, p=None, o=None, limit=20,
|
||||
user="trustgraph", collection="default",
|
||||
collection="default",
|
||||
timeout=30, g=None):
|
||||
|
||||
resp = await self.request(
|
||||
|
|
@ -54,7 +54,6 @@ class TriplesClient(RequestResponse):
|
|||
p = from_value(p),
|
||||
o = from_value(o),
|
||||
limit = limit,
|
||||
user = user,
|
||||
collection = collection,
|
||||
g = g,
|
||||
),
|
||||
|
|
@ -72,7 +71,7 @@ class TriplesClient(RequestResponse):
|
|||
return triples
|
||||
|
||||
async def query_stream(self, s=None, p=None, o=None, limit=20,
|
||||
user="trustgraph", collection="default",
|
||||
collection="default",
|
||||
batch_size=20, timeout=30,
|
||||
batch_callback=None, g=None):
|
||||
"""
|
||||
|
|
@ -81,7 +80,6 @@ class TriplesClient(RequestResponse):
|
|||
Args:
|
||||
s, p, o: Triple pattern (None for wildcard)
|
||||
limit: Maximum total triples to return
|
||||
user: User/keyspace
|
||||
collection: Collection name
|
||||
batch_size: Triples per batch
|
||||
timeout: Request timeout in seconds
|
||||
|
|
@ -116,7 +114,6 @@ class TriplesClient(RequestResponse):
|
|||
p=from_value(p),
|
||||
o=from_value(o),
|
||||
limit=limit,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=True,
|
||||
batch_size=batch_size,
|
||||
|
|
|
|||
|
|
@ -58,9 +58,13 @@ class TriplesQueryService(FlowProcessor):
|
|||
|
||||
logger.debug(f"Handling triples query request {id}...")
|
||||
|
||||
workspace = flow.workspace
|
||||
|
||||
if request.streaming:
|
||||
# Streaming mode: send batches
|
||||
async for batch, is_final in self.query_triples_stream(request):
|
||||
async for batch, is_final in self.query_triples_stream(
|
||||
workspace, request,
|
||||
):
|
||||
r = TriplesQueryResponse(
|
||||
triples=batch,
|
||||
error=None,
|
||||
|
|
@ -70,7 +74,7 @@ class TriplesQueryService(FlowProcessor):
|
|||
logger.debug("Triples query streaming completed")
|
||||
else:
|
||||
# Non-streaming mode: single response
|
||||
triples = await self.query_triples(request)
|
||||
triples = await self.query_triples(workspace, request)
|
||||
logger.debug("Sending triples query response...")
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
|
@ -92,13 +96,13 @@ class TriplesQueryService(FlowProcessor):
|
|||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
async def query_triples_stream(self, request):
|
||||
async def query_triples_stream(self, workspace, request):
|
||||
"""
|
||||
Streaming query - yields (batch, is_final) tuples.
|
||||
Default implementation batches results from query_triples.
|
||||
Override for true streaming from backend.
|
||||
"""
|
||||
triples = await self.query_triples(request)
|
||||
triples = await self.query_triples(workspace, request)
|
||||
batch_size = request.batch_size if request.batch_size > 0 else 20
|
||||
|
||||
for i in range(0, len(triples), batch_size):
|
||||
|
|
|
|||
|
|
@ -45,7 +45,10 @@ class TriplesStoreService(FlowProcessor):
|
|||
|
||||
request = msg.value()
|
||||
|
||||
await self.store_triples(request)
|
||||
# Workspace is derived from the flow the message arrived on,
|
||||
# not from fields in the message payload. Topic routing is
|
||||
# the isolation boundary.
|
||||
await self.store_triples(flow.workspace, request)
|
||||
|
||||
except TooManyRequests as e:
|
||||
raise e
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue