mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 09:26:22 +02:00
feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840)
Introduces `workspace` as the isolation boundary for config, flows,
library, and knowledge data. Removes `user` as a schema-level field
throughout the code, API specs, and tests; workspace provides the
same separation more cleanly at the trusted flow.workspace layer
rather than through client-supplied message fields.
Design
------
- IAM tech spec (docs/tech-specs/iam.md) documents current state,
proposed auth/access model, and migration direction.
- Data ownership model (docs/tech-specs/data-ownership-model.md)
captures the workspace/collection/flow hierarchy.
Schema + messaging
------------------
- Drop `user` field from AgentRequest/Step, GraphRagQuery,
DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest,
Sparql/Rows/Structured QueryRequest, ToolServiceRequest.
- Keep collection/workspace routing via flow.workspace at the
service layer.
- Translators updated to not serialise/deserialise user.
API specs
---------
- OpenAPI schemas and path examples cleaned of user fields.
- Websocket async-api messages updated.
- Removed the unused parameters/User.yaml.
Services + base
---------------
- Librarian, collection manager, knowledge, config: all operations
scoped by workspace. Config client API takes workspace as first
positional arg.
- `flow.workspace` set at flow start time by the infrastructure;
no longer pass-through from clients.
- Tool service drops user-personalisation passthrough.
CLI + SDK
---------
- tg-init-workspace and workspace-aware import/export.
- All tg-* commands drop user args; accept --workspace.
- Python API/SDK (flow, socket_client, async_*, explainability,
library) drop user kwargs from every method signature.
MCP server
----------
- All tool endpoints drop user parameters; socket_manager no longer
keyed per user.
Flow service
------------
- Closure-based topic cleanup on flow stop: only delete topics
whose blueprint template was parameterised AND no remaining
live flow (across all workspaces) still resolves to that topic.
Three scopes fall out naturally from template analysis:
* {id} -> per-flow, deleted on stop
* {blueprint} -> per-blueprint, kept while any flow of the
same blueprint exists
* {workspace} -> per-workspace, kept while any flow in the
workspace exists
* literal -> global, never deleted (e.g. tg.request.librarian)
Fixes a bug where stopping a flow silently destroyed the global
librarian exchange, wedging all library operations until manual
restart.
RabbitMQ backend
----------------
- heartbeat=60, blocked_connection_timeout=300. Catches silently
dead connections (broker restart, orphaned channels, network
partitions) within ~2 heartbeat windows, so the consumer
reconnects and re-binds its queue rather than sitting forever
on a zombie connection.
Tests
-----
- Full test refresh: unit, integration, contract, provenance.
- Dropped user-field assertions and constructor kwargs across
~100 test files.
- Renamed user-collection isolation tests to workspace-collection.
This commit is contained in:
parent
9332089b3d
commit
d35473f7f7
377 changed files with 6868 additions and 5785 deletions
|
|
@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup mock agent manager
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
mock_agent_instance.tools = {}
|
||||
mock_agent_instance.additional_context = ""
|
||||
processor.agents["default"] = mock_agent_instance
|
||||
|
||||
# Mock react to call think and observe callbacks
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
|
|
@ -50,7 +53,6 @@ class TestAgentServiceNonStreaming:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = AgentRequest(
|
||||
question="What is 2 + 2?",
|
||||
user="trustgraph",
|
||||
streaming=False # Non-streaming mode
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
|
@ -58,6 +60,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup flow mock
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
|
||||
|
|
@ -129,6 +132,9 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup mock agent manager
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
mock_agent_instance.tools = {}
|
||||
mock_agent_instance.additional_context = ""
|
||||
processor.agents["default"] = mock_agent_instance
|
||||
|
||||
# Mock react to return Final directly
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
|
|
@ -140,7 +146,6 @@ class TestAgentServiceNonStreaming:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = AgentRequest(
|
||||
question="What is 2 + 2?",
|
||||
user="trustgraph",
|
||||
streaming=False # Non-streaming mode
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
|
@ -148,6 +153,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup flow mock
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
|
||||
|
|
|
|||
|
|
@ -11,13 +11,12 @@ from trustgraph.schema import AgentRequest, AgentStep
|
|||
from trustgraph.agent.orchestrator.aggregator import Aggregator
|
||||
|
||||
|
||||
def _make_request(question="Test question", user="testuser",
|
||||
def _make_request(question="Test question",
|
||||
collection="default", streaming=False,
|
||||
session_id="parent-session", task_type="research",
|
||||
framing="test framing", conversation_id="conv-1"):
|
||||
return AgentRequest(
|
||||
question=question,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=streaming,
|
||||
session_id=session_id,
|
||||
|
|
@ -127,7 +126,6 @@ class TestBuildSynthesisRequest:
|
|||
req = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
|
|
@ -148,7 +146,7 @@ class TestBuildSynthesisRequest:
|
|||
agg.record_completion("corr-1", "goal-b", "answer-b")
|
||||
|
||||
req = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# Last history step should be the synthesis step
|
||||
|
|
@ -168,7 +166,7 @@ class TestBuildSynthesisRequest:
|
|||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
|
||||
agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# Entry should be removed
|
||||
|
|
@ -178,7 +176,7 @@ class TestBuildSynthesisRequest:
|
|||
agg = Aggregator()
|
||||
with pytest.raises(RuntimeError, match="No results"):
|
||||
agg.build_synthesis_request(
|
||||
"unknown", "question", "user", "default",
|
||||
"unknown", "question", "default",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from trustgraph.agent.orchestrator.aggregator import Aggregator
|
|||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
|
|
@ -130,7 +129,6 @@ class TestAggregatorIntegration:
|
|||
synth = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
|
|
@ -160,7 +158,7 @@ class TestAggregatorIntegration:
|
|||
agg.record_completion("corr-1", "goal", "answer")
|
||||
|
||||
synth = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# correlation_id must be empty so it's not intercepted
|
||||
|
|
|
|||
|
|
@ -126,7 +126,6 @@ def make_base_request(**kwargs):
|
|||
state="",
|
||||
group=[],
|
||||
history=[],
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id="test-session-123",
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ class MockProcessor:
|
|||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
|
|
|
|||
|
|
@ -167,39 +167,28 @@ class TestToolServiceRequest:
|
|||
"""Test cases for tool service request format"""
|
||||
|
||||
def test_request_format(self):
|
||||
"""Test that request is properly formatted with user, config, and arguments"""
|
||||
# Arrange
|
||||
user = "alice"
|
||||
"""Test that request is properly formatted with config and arguments"""
|
||||
config_values = {"style": "pun", "collection": "jokes"}
|
||||
arguments = {"topic": "programming"}
|
||||
|
||||
# Act - simulate request building
|
||||
request = {
|
||||
"user": user,
|
||||
"config": json.dumps(config_values),
|
||||
"arguments": json.dumps(arguments)
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert request["user"] == "alice"
|
||||
assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"}
|
||||
assert json.loads(request["arguments"]) == {"topic": "programming"}
|
||||
|
||||
def test_request_with_empty_config(self):
|
||||
"""Test request when no config values are provided"""
|
||||
# Arrange
|
||||
user = "bob"
|
||||
config_values = {}
|
||||
arguments = {"query": "test"}
|
||||
|
||||
# Act
|
||||
request = {
|
||||
"user": user,
|
||||
"config": json.dumps(config_values) if config_values else "{}",
|
||||
"arguments": json.dumps(arguments) if arguments else "{}"
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert request["config"] == "{}"
|
||||
assert json.loads(request["arguments"]) == {"query": "test"}
|
||||
|
||||
|
|
@ -386,18 +375,13 @@ class TestJokeServiceLogic:
|
|||
assert map_topic_to_category("random topic") == "default"
|
||||
assert map_topic_to_category("") == "default"
|
||||
|
||||
def test_joke_response_personalization(self):
|
||||
"""Test that joke responses include user personalization"""
|
||||
# Arrange
|
||||
user = "alice"
|
||||
def test_joke_response_format(self):
|
||||
"""Test that joke response is formatted as expected"""
|
||||
style = "pun"
|
||||
joke = "Why do programmers prefer dark mode? Because light attracts bugs!"
|
||||
|
||||
# Act
|
||||
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
|
||||
response = f"Here's a {style} for you:\n\n{joke}"
|
||||
|
||||
# Assert
|
||||
assert "Hey alice!" in response
|
||||
assert "pun" in response
|
||||
assert joke in response
|
||||
|
||||
|
|
@ -439,20 +423,14 @@ class TestDynamicToolServiceBase:
|
|||
|
||||
def test_request_parsing(self):
|
||||
"""Test parsing of incoming request"""
|
||||
# Arrange
|
||||
request_data = {
|
||||
"user": "alice",
|
||||
"config": '{"style": "pun"}',
|
||||
"arguments": '{"topic": "programming"}'
|
||||
}
|
||||
|
||||
# Act
|
||||
user = request_data.get("user", "trustgraph")
|
||||
config = json.loads(request_data["config"]) if request_data["config"] else {}
|
||||
arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {}
|
||||
|
||||
# Assert
|
||||
assert user == "alice"
|
||||
assert config == {"style": "pun"}
|
||||
assert arguments == {"topic": "programming"}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Tests for tool service lifecycle, invoke contract, streaming responses,
|
||||
multi-tenancy, and error propagation.
|
||||
and error propagation.
|
||||
|
||||
Tests the actual DynamicToolService, ToolService, and ToolServiceClient
|
||||
classes rather than plain dicts.
|
||||
|
|
@ -31,7 +31,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await svc.invoke("user", {}, {})
|
||||
await svc.invoke({}, {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_calls_invoke_with_parsed_args(self):
|
||||
|
|
@ -44,8 +44,8 @@ class TestDynamicToolServiceInvokeContract:
|
|||
|
||||
calls = []
|
||||
|
||||
async def tracking_invoke(user, config, arguments):
|
||||
calls.append({"user": user, "config": config, "arguments": arguments})
|
||||
async def tracking_invoke(config, arguments):
|
||||
calls.append({"config": config, "arguments": arguments})
|
||||
return "ok"
|
||||
|
||||
svc.invoke = tracking_invoke
|
||||
|
|
@ -56,7 +56,6 @@ class TestDynamicToolServiceInvokeContract:
|
|||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(
|
||||
user="alice",
|
||||
config='{"style": "pun"}',
|
||||
arguments='{"topic": "cats"}',
|
||||
)
|
||||
|
|
@ -65,39 +64,9 @@ class TestDynamicToolServiceInvokeContract:
|
|||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["user"] == "alice"
|
||||
assert calls[0]["config"] == {"style": "pun"}
|
||||
assert calls[0]["arguments"] == {"topic": "cats"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_empty_user_defaults_to_trustgraph(self):
|
||||
"""Empty user field should default to 'trustgraph'."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
received_user = None
|
||||
|
||||
async def capture_invoke(user, config, arguments):
|
||||
nonlocal received_user
|
||||
received_user = user
|
||||
return "ok"
|
||||
|
||||
svc.invoke = capture_invoke
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
|
||||
msg.properties.return_value = {"id": "req-2"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert received_user == "trustgraph"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_string_response_sent_directly(self):
|
||||
"""String return from invoke → response field is the string."""
|
||||
|
|
@ -107,7 +76,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def string_invoke(user, config, arguments):
|
||||
async def string_invoke(config, arguments):
|
||||
return "hello world"
|
||||
|
||||
svc.invoke = string_invoke
|
||||
|
|
@ -116,7 +85,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r1"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -136,7 +105,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def dict_invoke(user, config, arguments):
|
||||
async def dict_invoke(config, arguments):
|
||||
return {"result": 42}
|
||||
|
||||
svc.invoke = dict_invoke
|
||||
|
|
@ -145,7 +114,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r2"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -162,13 +131,13 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def failing_invoke(user, config, arguments):
|
||||
async def failing_invoke(config, arguments):
|
||||
raise ValueError("bad input")
|
||||
|
||||
svc.invoke = failing_invoke
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r3"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -188,13 +157,13 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def rate_limited_invoke(user, config, arguments):
|
||||
async def rate_limited_invoke(config, arguments):
|
||||
raise TooManyRequests("rate limited")
|
||||
|
||||
svc.invoke = rate_limited_invoke
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r4"}
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
|
|
@ -209,7 +178,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def ok_invoke(user, config, arguments):
|
||||
async def ok_invoke(config, arguments):
|
||||
return "ok"
|
||||
|
||||
svc.invoke = ok_invoke
|
||||
|
|
@ -218,7 +187,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "unique-42"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -241,7 +210,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
async def mock_invoke(workspace, name, params):
|
||||
return "tool result"
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
|
@ -260,6 +229,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
|
||||
|
|
@ -280,7 +250,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
async def mock_invoke(workspace, name, params):
|
||||
return {"data": [1, 2, 3]}
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
|
@ -298,6 +268,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
|
|
@ -317,7 +288,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def failing_invoke(name, params):
|
||||
async def failing_invoke(workspace, name, params):
|
||||
raise RuntimeError("tool broke")
|
||||
|
||||
svc.invoke_tool = failing_invoke
|
||||
|
|
@ -330,6 +301,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
|
|
@ -350,7 +322,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def rate_limited(name, params):
|
||||
async def rate_limited(workspace, name, params):
|
||||
raise TooManyRequests("slow down")
|
||||
|
||||
svc.invoke_tool = rate_limited
|
||||
|
|
@ -362,6 +334,7 @@ class TestToolServiceOnRequest:
|
|||
flow = MagicMock()
|
||||
flow.producer = {"response": AsyncMock()}
|
||||
flow.name = "test-flow"
|
||||
flow.workspace = "default"
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
await svc.on_request(msg, MagicMock(), flow)
|
||||
|
|
@ -376,7 +349,8 @@ class TestToolServiceOnRequest:
|
|||
|
||||
received = {}
|
||||
|
||||
async def capture_invoke(name, params):
|
||||
async def capture_invoke(workspace, name, params):
|
||||
received["workspace"] = workspace
|
||||
received["name"] = name
|
||||
received["params"] = params
|
||||
return "ok"
|
||||
|
|
@ -390,6 +364,7 @@ class TestToolServiceOnRequest:
|
|||
flow = lambda name: mock_pub
|
||||
flow.producer = {"response": mock_pub}
|
||||
flow.name = "f"
|
||||
flow.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(
|
||||
|
|
@ -421,7 +396,6 @@ class TestToolServiceClientCall:
|
|||
))
|
||||
|
||||
result = await client.call(
|
||||
user="alice",
|
||||
config={"style": "pun"},
|
||||
arguments={"topic": "cats"},
|
||||
)
|
||||
|
|
@ -430,7 +404,6 @@ class TestToolServiceClientCall:
|
|||
|
||||
req = client.request.call_args[0][0]
|
||||
assert isinstance(req, ToolServiceRequest)
|
||||
assert req.user == "alice"
|
||||
assert json.loads(req.config) == {"style": "pun"}
|
||||
assert json.loads(req.arguments) == {"topic": "cats"}
|
||||
|
||||
|
|
@ -446,7 +419,7 @@ class TestToolServiceClientCall:
|
|||
))
|
||||
|
||||
with pytest.raises(RuntimeError, match="service down"):
|
||||
await client.call(user="u", config={}, arguments={})
|
||||
await client.call(config={}, arguments={})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_empty_config_sends_empty_json(self):
|
||||
|
|
@ -458,7 +431,7 @@ class TestToolServiceClientCall:
|
|||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="u", config=None, arguments=None)
|
||||
await client.call(config=None, arguments=None)
|
||||
|
||||
req = client.request.call_args[0][0]
|
||||
assert req.config == "{}"
|
||||
|
|
@ -474,7 +447,7 @@ class TestToolServiceClientCall:
|
|||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="u", config={}, arguments={}, timeout=30)
|
||||
await client.call(config={}, arguments={}, timeout=30)
|
||||
|
||||
_, kwargs = client.request.call_args
|
||||
assert kwargs["timeout"] == 30
|
||||
|
|
@ -509,7 +482,7 @@ class TestToolServiceClientStreaming:
|
|||
received.append(text)
|
||||
|
||||
result = await client.call_streaming(
|
||||
user="u", config={}, arguments={}, callback=callback,
|
||||
config={}, arguments={}, callback=callback,
|
||||
)
|
||||
|
||||
assert result == "chunk1chunk2"
|
||||
|
|
@ -534,7 +507,7 @@ class TestToolServiceClientStreaming:
|
|||
|
||||
with pytest.raises(RuntimeError, match="stream failed"):
|
||||
await client.call_streaming(
|
||||
user="u", config={}, arguments={},
|
||||
config={}, arguments={},
|
||||
callback=AsyncMock(),
|
||||
)
|
||||
|
||||
|
|
@ -564,61 +537,9 @@ class TestToolServiceClientStreaming:
|
|||
received.append(text)
|
||||
|
||||
result = await client.call_streaming(
|
||||
user="u", config={}, arguments={}, callback=callback,
|
||||
config={}, arguments={}, callback=callback,
|
||||
)
|
||||
|
||||
# Empty response is falsy, so callback shouldn't be called for it
|
||||
assert result == "data"
|
||||
assert received == ["data"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-tenancy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMultiTenancy:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_propagated_to_invoke(self):
|
||||
"""User from request should reach the invoke method."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
users_seen = []
|
||||
|
||||
async def tracking(user, config, arguments):
|
||||
users_seen.append(user)
|
||||
return "ok"
|
||||
|
||||
svc.invoke = tracking
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
for u in ["tenant-a", "tenant-b", "tenant-c"]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(
|
||||
user=u, config="{}", arguments="{}",
|
||||
)
|
||||
msg.properties.return_value = {"id": f"req-{u}"}
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_sends_user_in_request(self):
|
||||
"""ToolServiceClient.call should include user in request."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
client.request = AsyncMock(return_value=ToolServiceResponse(
|
||||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="isolated-tenant", config={}, arguments={})
|
||||
|
||||
req = client.request.call_args[0][0]
|
||||
assert req.user == "isolated-tenant"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue