# trustgraph/tests/unit/test_agent/test_tool_service_lifecycle.py
# 2026-04-22 15:19:57 +01:00
#
# 545 lines
# 18 KiB
# Python

"""
Tests for tool service lifecycle, invoke contract, streaming responses,
and error propagation.
Tests the actual DynamicToolService, ToolService, and ToolServiceClient
classes rather than plain dicts.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.schema import (
ToolServiceRequest, ToolServiceResponse, Error,
ToolRequest, ToolResponse,
)
from trustgraph.exceptions import TooManyRequests
# ---------------------------------------------------------------------------
# DynamicToolService tests
# ---------------------------------------------------------------------------
class TestDynamicToolServiceInvokeContract:
    """Contract tests for ``DynamicToolService.invoke`` and ``on_request``.

    Services are built with ``__new__`` so ``__init__`` never runs; each
    test sets only the attributes it needs and swaps in its own ``invoke``.
    """

    @pytest.fixture(autouse=True)
    def _scoped_metric(self):
        """Provide a throwaway class-level ``tool_service_metric`` per test.

        Because ``__init__`` is bypassed, the metric may not exist on the
        class.  Assigning a mock to the class directly would outlive the
        test and leak into every later test in the session, so it is
        patched in here and removed (or the previous value restored) on
        teardown instead.
        """
        from trustgraph.base.dynamic_tool_service import DynamicToolService
        with patch.object(
            DynamicToolService, "tool_service_metric", MagicMock(),
            create=True,
        ):
            yield

    @staticmethod
    def _make_service(invoke):
        """Return a bare ``DynamicToolService`` wired to *invoke*.

        Only ``id``, ``producer`` (an ``AsyncMock``) and ``invoke`` are set.
        """
        from trustgraph.base.dynamic_tool_service import DynamicToolService
        svc = DynamicToolService.__new__(DynamicToolService)
        svc.id = "test-svc"
        svc.producer = AsyncMock()
        svc.invoke = invoke
        return svc

    @staticmethod
    def _make_message(request_id, config="{}", arguments="{}"):
        """Return a mock message carrying a ``ToolServiceRequest``.

        ``msg.value()`` yields the request and ``msg.properties()`` yields
        ``{"id": request_id}``.
        """
        msg = MagicMock()
        msg.value.return_value = ToolServiceRequest(
            config=config,
            arguments=arguments,
        )
        msg.properties.return_value = {"id": request_id}
        return msg

    @pytest.mark.asyncio
    async def test_base_invoke_raises_not_implemented(self):
        """Base class invoke() should raise NotImplementedError."""
        from trustgraph.base.dynamic_tool_service import DynamicToolService
        svc = DynamicToolService.__new__(DynamicToolService)

        with pytest.raises(NotImplementedError):
            await svc.invoke({}, {})

    @pytest.mark.asyncio
    async def test_on_request_calls_invoke_with_parsed_args(self):
        """on_request should JSON-parse config/arguments and pass to invoke."""
        calls = []

        async def tracking_invoke(config, arguments):
            calls.append({"config": config, "arguments": arguments})
            return "ok"

        svc = self._make_service(tracking_invoke)
        msg = self._make_message(
            "req-1",
            config='{"style": "pun"}',
            arguments='{"topic": "cats"}',
        )

        await svc.on_request(msg, MagicMock(), None)

        assert len(calls) == 1
        assert calls[0]["config"] == {"style": "pun"}
        assert calls[0]["arguments"] == {"topic": "cats"}

    @pytest.mark.asyncio
    async def test_on_request_string_response_sent_directly(self):
        """String return from invoke → response field is the string."""
        async def string_invoke(config, arguments):
            return "hello world"

        svc = self._make_service(string_invoke)

        await svc.on_request(self._make_message("r1"), MagicMock(), None)

        sent = svc.producer.send.call_args[0][0]
        assert isinstance(sent, ToolServiceResponse)
        assert sent.response == "hello world"
        assert sent.end_of_stream is True
        assert sent.error is None

    @pytest.mark.asyncio
    async def test_on_request_dict_response_json_encoded(self):
        """Dict return from invoke → response field is JSON-encoded."""
        async def dict_invoke(config, arguments):
            return {"result": 42}

        svc = self._make_service(dict_invoke)

        await svc.on_request(self._make_message("r2"), MagicMock(), None)

        sent = svc.producer.send.call_args[0][0]
        assert json.loads(sent.response) == {"result": 42}

    @pytest.mark.asyncio
    async def test_on_request_error_sends_error_response(self):
        """Exception in invoke → error response sent."""
        async def failing_invoke(config, arguments):
            raise ValueError("bad input")

        svc = self._make_service(failing_invoke)

        await svc.on_request(self._make_message("r3"), MagicMock(), None)

        sent = svc.producer.send.call_args[0][0]
        assert sent.error is not None
        assert sent.error.type == "tool-service-error"
        assert "bad input" in sent.error.message
        assert sent.response == ""

    @pytest.mark.asyncio
    async def test_on_request_too_many_requests_propagates(self):
        """TooManyRequests should propagate (not caught as error response)."""
        async def rate_limited_invoke(config, arguments):
            raise TooManyRequests("rate limited")

        svc = self._make_service(rate_limited_invoke)

        with pytest.raises(TooManyRequests):
            await svc.on_request(self._make_message("r4"), MagicMock(), None)

    @pytest.mark.asyncio
    async def test_on_request_preserves_message_id(self):
        """Response should include the original message id in properties."""
        async def ok_invoke(config, arguments):
            return "ok"

        svc = self._make_service(ok_invoke)

        await svc.on_request(
            self._make_message("unique-42"), MagicMock(), None,
        )

        props = svc.producer.send.call_args[1]["properties"]
        assert props["id"] == "unique-42"
# ---------------------------------------------------------------------------
# ToolService (flow-based) tests
# ---------------------------------------------------------------------------
class TestToolServiceOnRequest:
    """Tests for the flow-based ``ToolService.on_request`` handler.

    Services are built with ``__new__`` so ``__init__`` never runs; each
    test supplies its own ``invoke_tool`` coroutine and a flow double.
    """

    @pytest.fixture(autouse=True)
    def _scoped_metric(self):
        """Provide a throwaway class-level ``tool_invocation_metric`` per test.

        Because ``__init__`` is bypassed, the metric may not exist on the
        class.  Assigning a mock to the class directly would outlive the
        test and leak into every later test in the session, so it is
        patched in here and removed (or the previous value restored) on
        teardown instead.
        """
        from trustgraph.base.tool_service import ToolService
        with patch.object(
            ToolService, "tool_invocation_metric", MagicMock(),
            create=True,
        ):
            yield

    @staticmethod
    def _make_service(invoke_tool):
        """Return a bare ``ToolService`` whose ``invoke_tool`` is replaced."""
        from trustgraph.base.tool_service import ToolService
        svc = ToolService.__new__(ToolService)
        svc.id = "test-tool"
        svc.invoke_tool = invoke_tool
        return svc

    @staticmethod
    def _make_message(request_id, parameters="{}"):
        """Return a mock message carrying a ``ToolRequest`` for "my-tool"."""
        msg = MagicMock()
        msg.value.return_value = ToolRequest(
            name="my-tool", parameters=parameters,
        )
        msg.properties.return_value = {"id": request_id}
        return msg

    @staticmethod
    def _make_flow(response_pub, resolves_response=True):
        """Return a callable flow double exposing *response_pub*.

        The double always carries ``producer["response"]``, ``name`` and
        ``workspace``.  Calling it with ``"response"`` returns
        *response_pub* only when *resolves_response* is true; every other
        lookup yields a fresh ``MagicMock``.
        """
        def flow_callable(name):
            if resolves_response and name == "response":
                return response_pub
            return MagicMock()

        flow_callable.producer = {"response": response_pub}
        flow_callable.name = "test-flow"
        flow_callable.workspace = "default"
        return flow_callable

    @pytest.mark.asyncio
    async def test_string_response_sent_as_text(self):
        """String return from invoke_tool → ToolResponse.text is set."""
        async def mock_invoke(workspace, name, params):
            return "tool result"

        svc = self._make_service(mock_invoke)
        mock_response_pub = AsyncMock()
        flow_callable = self._make_flow(mock_response_pub)
        msg = self._make_message("t1", parameters='{"key": "val"}')

        await svc.on_request(msg, MagicMock(), flow_callable)

        sent = mock_response_pub.send.call_args[0][0]
        assert isinstance(sent, ToolResponse)
        assert sent.text == "tool result"
        assert sent.object is None

    @pytest.mark.asyncio
    async def test_dict_response_sent_as_json_object(self):
        """Dict return from invoke_tool → ToolResponse.object is JSON."""
        async def mock_invoke(workspace, name, params):
            return {"data": [1, 2, 3]}

        svc = self._make_service(mock_invoke)
        mock_response_pub = AsyncMock()
        flow_callable = self._make_flow(mock_response_pub)

        await svc.on_request(
            self._make_message("t2"), MagicMock(), flow_callable,
        )

        sent = mock_response_pub.send.call_args[0][0]
        assert sent.text is None
        assert json.loads(sent.object) == {"data": [1, 2, 3]}

    @pytest.mark.asyncio
    async def test_error_sends_error_response(self):
        """Exception in invoke_tool → error response via flow producer."""
        async def failing_invoke(workspace, name, params):
            raise RuntimeError("tool broke")

        svc = self._make_service(failing_invoke)
        mock_response_pub = AsyncMock()
        # Calling the flow never yields the publisher here, so the error
        # response can only be observed through flow.producer["response"].
        flow_callable = self._make_flow(
            mock_response_pub, resolves_response=False,
        )

        await svc.on_request(
            self._make_message("t3"), MagicMock(), flow_callable,
        )

        sent = mock_response_pub.send.call_args[0][0]
        assert sent.error is not None
        assert sent.error.type == "tool-error"
        assert "tool broke" in sent.error.message

    @pytest.mark.asyncio
    async def test_too_many_requests_propagates(self):
        """TooManyRequests should propagate from ToolService.on_request."""
        async def rate_limited(workspace, name, params):
            raise TooManyRequests("slow down")

        svc = self._make_service(rate_limited)
        flow = MagicMock()
        flow.producer = {"response": AsyncMock()}
        flow.name = "test-flow"
        flow.workspace = "default"

        with pytest.raises(TooManyRequests):
            await svc.on_request(self._make_message("t4"), MagicMock(), flow)

    @pytest.mark.asyncio
    async def test_parameters_json_parsed(self):
        """Parameters should be JSON-parsed before passing to invoke_tool."""
        received = {}

        async def capture_invoke(workspace, name, params):
            received["workspace"] = workspace
            received["name"] = name
            received["params"] = params
            return "ok"

        svc = self._make_service(capture_invoke)
        mock_pub = AsyncMock()

        # Every lookup on this flow resolves to the same publisher.
        def flow(name):
            return mock_pub

        flow.producer = {"response": mock_pub}
        flow.name = "f"
        flow.workspace = "default"

        msg = MagicMock()
        msg.value.return_value = ToolRequest(
            name="search",
            parameters='{"query": "test", "limit": 10}',
        )
        msg.properties.return_value = {"id": "t5"}

        await svc.on_request(msg, MagicMock(), flow)

        assert received["name"] == "search"
        assert received["params"] == {"query": "test", "limit": 10}
# ---------------------------------------------------------------------------
# ToolServiceClient tests
# ---------------------------------------------------------------------------
class TestToolServiceClientCall:
    """Tests for the non-streaming ``ToolServiceClient.call`` path.

    The client is created with ``__new__`` and its ``request`` coroutine is
    replaced by an ``AsyncMock`` returning a canned ``ToolServiceResponse``.
    """

    @staticmethod
    def _client_replying(reply):
        """Return a bare client whose ``request`` resolves to *reply*."""
        from trustgraph.base.tool_service_client import ToolServiceClient
        stub = ToolServiceClient.__new__(ToolServiceClient)
        stub.request = AsyncMock(return_value=reply)
        return stub

    @pytest.mark.asyncio
    async def test_call_sends_request_and_returns_response(self):
        """call() should send ToolServiceRequest and return response string."""
        reply = ToolServiceResponse(
            error=None, response="joke result", end_of_stream=True,
        )
        stub = self._client_replying(reply)

        outcome = await stub.call(
            config={"style": "pun"},
            arguments={"topic": "cats"},
        )

        assert outcome == "joke result"
        outgoing = stub.request.call_args.args[0]
        assert isinstance(outgoing, ToolServiceRequest)
        assert json.loads(outgoing.config) == {"style": "pun"}
        assert json.loads(outgoing.arguments) == {"topic": "cats"}

    @pytest.mark.asyncio
    async def test_call_raises_on_error(self):
        """call() should raise RuntimeError when response has error."""
        failure = ToolServiceResponse(
            error=Error(type="tool-service-error", message="service down"),
            response="",
        )
        stub = self._client_replying(failure)

        with pytest.raises(RuntimeError, match="service down"):
            await stub.call(config={}, arguments={})

    @pytest.mark.asyncio
    async def test_call_empty_config_sends_empty_json(self):
        """Empty config/arguments should be sent as '{}'."""
        stub = self._client_replying(
            ToolServiceResponse(error=None, response="ok"),
        )

        await stub.call(config=None, arguments=None)

        outgoing = stub.request.call_args.args[0]
        assert outgoing.config == "{}"
        assert outgoing.arguments == "{}"

    @pytest.mark.asyncio
    async def test_call_passes_timeout(self):
        """call() should forward timeout to underlying request."""
        stub = self._client_replying(
            ToolServiceResponse(error=None, response="ok"),
        )

        await stub.call(config={}, arguments={}, timeout=30)

        forwarded = stub.request.call_args.kwargs
        assert forwarded["timeout"] == 30
class TestToolServiceClientStreaming:
    """Tests for ``ToolServiceClient.call_streaming``.

    The client's ``request`` is replaced by a coroutine that feeds canned
    ``ToolServiceResponse`` chunks to the supplied ``recipient``.
    """

    @staticmethod
    def _client_replaying(chunks):
        """Return a bare client whose ``request`` replays *chunks*.

        Each chunk is handed to ``recipient`` in order; replay stops as
        soon as ``recipient`` returns a truthy value.
        """
        from trustgraph.base.tool_service_client import ToolServiceClient
        stub = ToolServiceClient.__new__(ToolServiceClient)

        async def mock_request(req, timeout=600, recipient=None):
            for piece in chunks:
                if await recipient(piece):
                    break

        stub.request = mock_request
        return stub

    @pytest.mark.asyncio
    async def test_call_streaming_collects_chunks(self):
        """call_streaming should accumulate chunks and return full result."""
        stub = self._client_replaying([
            ToolServiceResponse(
                error=None, response="chunk1", end_of_stream=False,
            ),
            ToolServiceResponse(
                error=None, response="chunk2", end_of_stream=True,
            ),
        ])
        seen = []

        async def callback(text):
            seen.append(text)

        combined = await stub.call_streaming(
            config={}, arguments={}, callback=callback,
        )

        assert combined == "chunk1chunk2"
        assert seen == ["chunk1", "chunk2"]

    @pytest.mark.asyncio
    async def test_call_streaming_raises_on_error(self):
        """call_streaming should raise RuntimeError on error chunk."""
        stub = self._client_replaying([
            ToolServiceResponse(
                error=Error(
                    type="tool-service-error", message="stream failed",
                ),
                response="",
                end_of_stream=True,
            ),
        ])

        with pytest.raises(RuntimeError, match="stream failed"):
            await stub.call_streaming(
                config={}, arguments={},
                callback=AsyncMock(),
            )

    @pytest.mark.asyncio
    async def test_call_streaming_skips_empty_response(self):
        """Empty response chunks should not be added to result."""
        stub = self._client_replaying([
            ToolServiceResponse(
                error=None, response="", end_of_stream=False,
            ),
            ToolServiceResponse(
                error=None, response="data", end_of_stream=True,
            ),
        ])
        seen = []

        async def callback(text):
            seen.append(text)

        combined = await stub.call_streaming(
            config={}, arguments={}, callback=callback,
        )

        # Empty response is falsy, so callback shouldn't be called for it
        assert combined == "data"
        assert seen == ["data"]