mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
133 lines
3.8 KiB
Python
133 lines
3.8 KiB
Python
|
|
"""
Tests for the on_action callback in react() — verifies that it fires
after action selection but before tool execution.
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import AsyncMock, MagicMock
|
||
|
|
|
||
|
|
from trustgraph.agent.react.agent_manager import AgentManager
|
||
|
|
from trustgraph.agent.react.types import Action, Final, Tool, Argument
|
||
|
|
|
||
|
|
|
||
|
|
class TestOnActionCallback:
    """Exercises the optional ``on_action`` hook accepted by ``react()``."""

    @pytest.mark.asyncio
    async def test_on_action_called_for_tool_use(self):
        """The hook runs for a tool-selecting step, ahead of the tool call."""
        events = []

        async def record_action(selected):
            events.append(("on_action", selected.name))

        async def record_invoke(**kwargs):
            # Logged into the same list as the hook so ordering can be checked.
            events.append(("tool_invoke",))
            return "tool result"

        implementation = MagicMock()
        implementation.return_value.invoke = AsyncMock(side_effect=record_invoke)

        query_arg = Argument(name="query", type="string", description="q")
        search_tool = Tool(
            name="search",
            description="Search",
            implementation=implementation,
            arguments=[query_arg],
            config={},
        )
        manager = AgentManager(tools={"search": search_tool})

        # Replace reason() so it always hands back a "search" Action.
        manager.reason = AsyncMock(
            return_value=Action(
                thought="thinking",
                name="search",
                arguments={"query": "test"},
                observation="",
            ),
        )

        await manager.react(
            question="test",
            history=[],
            think=AsyncMock(),
            observe=AsyncMock(),
            context=MagicMock(),
            on_action=record_action,
        )

        # Exactly two events, hook first and tool invocation second.
        assert events == [("on_action", "search"), ("tool_invoke",)]

    @pytest.mark.asyncio
    async def test_on_action_not_called_for_final(self):
        """No hook invocation is recorded when reason() yields a Final."""
        seen = []

        async def record_action(selected):
            seen.append(selected)

        manager = AgentManager(tools={})
        final_step = Final(thought="done", final="answer")
        manager.reason = AsyncMock(return_value=final_step)

        outcome = await manager.react(
            question="test",
            history=[],
            think=AsyncMock(),
            observe=AsyncMock(),
            context=MagicMock(),
            on_action=record_action,
        )

        assert isinstance(outcome, Final)
        assert not seen

    @pytest.mark.asyncio
    async def test_on_action_none_accepted(self):
        """react() completes a tool step when on_action is not supplied."""
        implementation = MagicMock()
        # Stub tool: awaiting invoke() simply yields "result".
        implementation.return_value.invoke = AsyncMock(return_value="result")

        search_tool = Tool(
            name="search",
            description="Search",
            implementation=implementation,
            arguments=[],
            config={},
        )
        manager = AgentManager(tools={"search": search_tool})

        chosen = Action(thought="t", name="search", arguments={}, observation="")
        manager.reason = AsyncMock(return_value=chosen)

        # on_action is deliberately left out of the call.
        outcome = await manager.react(
            question="test",
            history=[],
            think=AsyncMock(),
            observe=AsyncMock(),
            context=MagicMock(),
        )

        assert isinstance(outcome, Action)
        assert outcome.observation == "result"
|