mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
chore: linting
This commit is contained in:
parent
b9a66cb417
commit
ca9bbee06d
41 changed files with 314 additions and 244 deletions
|
|
@ -90,9 +90,7 @@ class TestCompose:
|
|||
assert "<citation_instructions>" in prompt
|
||||
assert "[citation:chunk_id]" in prompt
|
||||
|
||||
def test_team_visibility_uses_team_variants(
|
||||
self, fixed_today: datetime
|
||||
) -> None:
|
||||
def test_team_visibility_uses_team_variants(self, fixed_today: datetime) -> None:
|
||||
prompt = compose_system_prompt(
|
||||
today=fixed_today,
|
||||
thread_visibility=ChatVisibility.SEARCH_SPACE,
|
||||
|
|
@ -145,9 +143,7 @@ class TestCompose:
|
|||
assert "Generate Image" in prompt
|
||||
assert "Generate Podcast" in prompt
|
||||
|
||||
def test_mcp_routing_block_emits_when_provided(
|
||||
self, fixed_today: datetime
|
||||
) -> None:
|
||||
def test_mcp_routing_block_emits_when_provided(self, fixed_today: datetime) -> None:
|
||||
prompt = compose_system_prompt(
|
||||
today=fixed_today,
|
||||
mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]},
|
||||
|
|
@ -162,9 +158,7 @@ class TestCompose:
|
|||
prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={})
|
||||
assert "<mcp_tool_routing>" not in prompt
|
||||
|
||||
def test_provider_block_renders_when_anthropic(
|
||||
self, fixed_today: datetime
|
||||
) -> None:
|
||||
def test_provider_block_renders_when_anthropic(self, fixed_today: datetime) -> None:
|
||||
prompt = compose_system_prompt(
|
||||
today=fixed_today, model_name="anthropic:claude-3-5-sonnet"
|
||||
)
|
||||
|
|
@ -267,7 +261,10 @@ class TestStableOrderingForCacheStability:
|
|||
)
|
||||
b = compose_system_prompt(
|
||||
today=fixed_today,
|
||||
enabled_tool_names={"scrape_webpage", "web_search"}, # set order shouldn't matter
|
||||
enabled_tool_names={
|
||||
"scrape_webpage",
|
||||
"web_search",
|
||||
}, # set order shouldn't matter
|
||||
mcp_connector_tools={"X": ["x_a", "x_b"]},
|
||||
)
|
||||
assert a == b
|
||||
|
|
|
|||
|
|
@ -83,7 +83,11 @@ class TestActionLogMiddlewareDisabled:
|
|||
async def test_no_op_when_flag_off(self, patch_get_flags) -> None:
|
||||
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
|
||||
request = _FakeRequest(
|
||||
tool_call={"name": "make_widget", "args": {"color": "red", "size": 1}, "id": "tc1"}
|
||||
tool_call={
|
||||
"name": "make_widget",
|
||||
"args": {"color": "red", "size": 1},
|
||||
"id": "tc1",
|
||||
}
|
||||
)
|
||||
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
|
||||
with patch_get_flags(_disabled_flags()):
|
||||
|
|
@ -117,13 +121,12 @@ class TestActionLogMiddlewarePersistence:
|
|||
"id": "tc-abc",
|
||||
},
|
||||
)
|
||||
result_msg = ToolMessage(
|
||||
content="ok", tool_call_id="tc-abc", id="msg-1"
|
||||
)
|
||||
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
|
||||
handler = AsyncMock(return_value=result_msg)
|
||||
|
||||
with patch_get_flags(_enabled_flags()), patch(
|
||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
||||
with (
|
||||
patch_get_flags(_enabled_flags()),
|
||||
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||
):
|
||||
result = await mw.awrap_tool_call(request, handler)
|
||||
|
||||
|
|
@ -151,9 +154,11 @@ class TestActionLogMiddlewarePersistence:
|
|||
)
|
||||
handler = AsyncMock(side_effect=ValueError("boom"))
|
||||
|
||||
with patch_get_flags(_enabled_flags()), patch(
|
||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
||||
), pytest.raises(ValueError, match="boom"):
|
||||
with (
|
||||
patch_get_flags(_enabled_flags()),
|
||||
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||
pytest.raises(ValueError, match="boom"),
|
||||
):
|
||||
await mw.awrap_tool_call(request, handler)
|
||||
|
||||
assert len(captured["rows"]) == 1
|
||||
|
|
@ -177,8 +182,9 @@ class TestActionLogMiddlewarePersistence:
|
|||
def _exploding_session():
|
||||
raise RuntimeError("DB is down")
|
||||
|
||||
with patch_get_flags(_enabled_flags()), patch(
|
||||
"app.db.shielded_async_session", side_effect=_exploding_session
|
||||
with (
|
||||
patch_get_flags(_enabled_flags()),
|
||||
patch("app.db.shielded_async_session", side_effect=_exploding_session),
|
||||
):
|
||||
result = await mw.awrap_tool_call(request, handler)
|
||||
assert result is result_msg
|
||||
|
|
@ -218,8 +224,9 @@ class TestReverseDescriptor:
|
|||
)
|
||||
handler = AsyncMock(return_value=result_msg)
|
||||
|
||||
with patch_get_flags(_enabled_flags()), patch(
|
||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
||||
with (
|
||||
patch_get_flags(_enabled_flags()),
|
||||
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||
):
|
||||
await mw.awrap_tool_call(request, handler)
|
||||
|
||||
|
|
@ -257,8 +264,9 @@ class TestReverseDescriptor:
|
|||
result_msg = ToolMessage(content="ok", tool_call_id="tc1")
|
||||
handler = AsyncMock(return_value=result_msg)
|
||||
|
||||
with patch_get_flags(_enabled_flags()), patch(
|
||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
||||
with (
|
||||
patch_get_flags(_enabled_flags()),
|
||||
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||
):
|
||||
await mw.awrap_tool_call(request, handler)
|
||||
|
||||
|
|
@ -275,11 +283,10 @@ class TestReverseDescriptor:
|
|||
request = _FakeRequest(
|
||||
tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"}
|
||||
)
|
||||
handler = AsyncMock(
|
||||
return_value=ToolMessage(content="ok", tool_call_id="tc1")
|
||||
)
|
||||
with patch_get_flags(_enabled_flags()), patch(
|
||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
||||
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
|
||||
with (
|
||||
patch_get_flags(_enabled_flags()),
|
||||
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||
):
|
||||
await mw.awrap_tool_call(request, handler)
|
||||
row = captured["rows"][0]
|
||||
|
|
@ -298,11 +305,10 @@ class TestArgsTruncation:
|
|||
request = _FakeRequest(
|
||||
tool_call={"name": "make_widget", "args": {"blob": huge}, "id": "tc1"},
|
||||
)
|
||||
handler = AsyncMock(
|
||||
return_value=ToolMessage(content="ok", tool_call_id="tc1")
|
||||
)
|
||||
with patch_get_flags(_enabled_flags()), patch(
|
||||
"app.db.shielded_async_session", side_effect=lambda: factory()
|
||||
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
|
||||
with (
|
||||
patch_get_flags(_enabled_flags()),
|
||||
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||
):
|
||||
await mw.awrap_tool_call(request, handler)
|
||||
row = captured["rows"][0]
|
||||
|
|
|
|||
|
|
@ -26,10 +26,16 @@ class TestIsProtectedSystemMessage:
|
|||
assert _is_protected_system_message(msg) is True
|
||||
|
||||
def test_unprotected_system_message(self) -> None:
|
||||
assert _is_protected_system_message(SystemMessage(content="random instructions")) is False
|
||||
assert (
|
||||
_is_protected_system_message(SystemMessage(content="random instructions"))
|
||||
is False
|
||||
)
|
||||
|
||||
def test_human_message_never_protected(self) -> None:
|
||||
assert _is_protected_system_message(HumanMessage(content="<workspace_tree>...")) is False
|
||||
assert (
|
||||
_is_protected_system_message(HumanMessage(content="<workspace_tree>..."))
|
||||
is False
|
||||
)
|
||||
|
||||
def test_tolerates_leading_whitespace(self) -> None:
|
||||
msg = SystemMessage(content=" \n<priority_documents>\n...")
|
||||
|
|
@ -97,11 +103,17 @@ class TestPartitionMessages:
|
|||
assert protected not in to_summary
|
||||
assert protected in preserved
|
||||
# The non-protected old messages remain in to_summary
|
||||
assert any(isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary)
|
||||
assert any(
|
||||
isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary
|
||||
)
|
||||
|
||||
def test_unprotected_messages_unaffected(self) -> None:
|
||||
partitioner = self._build_partitioner()
|
||||
msgs = [HumanMessage(content="a"), HumanMessage(content="b"), HumanMessage(content="c")]
|
||||
msgs = [
|
||||
HumanMessage(content="a"),
|
||||
HumanMessage(content="b"),
|
||||
HumanMessage(content="c"),
|
||||
]
|
||||
to_summary, preserved = partitioner._partition_messages(msgs, 2)
|
||||
assert [m.content for m in to_summary] == ["a", "b"]
|
||||
assert [m.content for m in preserved] == ["c"]
|
||||
|
|
|
|||
|
|
@ -70,7 +70,8 @@ class TestSpillEdit:
|
|||
|
||||
# Earlier ToolMessages should now contain the placeholder text
|
||||
cleared = [
|
||||
m for m in tool_messages
|
||||
m
|
||||
for m in tool_messages
|
||||
if isinstance(m.content, str) and m.content.startswith("[cleared")
|
||||
]
|
||||
assert len(cleared) >= 1
|
||||
|
|
|
|||
|
|
@ -46,9 +46,21 @@ def test_callable_dedup_key_takes_priority() -> None:
|
|||
state = {
|
||||
"messages": [
|
||||
_msg(
|
||||
{"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "1"},
|
||||
{"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "2"},
|
||||
{"name": "create_doc", "args": {"parent_id": "x", "title": "z"}, "id": "3"},
|
||||
{
|
||||
"name": "create_doc",
|
||||
"args": {"parent_id": "x", "title": "y"},
|
||||
"id": "1",
|
||||
},
|
||||
{
|
||||
"name": "create_doc",
|
||||
"args": {"parent_id": "x", "title": "y"},
|
||||
"id": "2",
|
||||
},
|
||||
{
|
||||
"name": "create_doc",
|
||||
"args": {"parent_id": "x", "title": "z"},
|
||||
"id": "3",
|
||||
},
|
||||
)
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -84,9 +84,7 @@ class TestConnectorDenyOverridesDefaultAllow:
|
|||
Rule(permission="linear_create_issue", pattern="*", action="deny")
|
||||
]
|
||||
)
|
||||
rules = evaluate_many(
|
||||
"linear_create_issue", ["linear_create_issue"], *rulesets
|
||||
)
|
||||
rules = evaluate_many("linear_create_issue", ["linear_create_issue"], *rulesets)
|
||||
assert aggregate_action(rules) == "deny"
|
||||
|
||||
def test_default_allow_still_applies_to_other_tools(self) -> None:
|
||||
|
|
@ -124,5 +122,7 @@ class TestUserRuleOverridesDefault:
|
|||
rules=[Rule(permission="send_*", pattern="*", action="deny")],
|
||||
origin="user",
|
||||
)
|
||||
rules = evaluate_many("send_gmail_email", ["send_gmail_email"], defaults, user_ruleset)
|
||||
rules = evaluate_many(
|
||||
"send_gmail_email", ["send_gmail_email"], defaults, user_ruleset
|
||||
)
|
||||
assert aggregate_action(rules) == "deny"
|
||||
|
|
|
|||
|
|
@ -64,22 +64,17 @@ def test_threshold_triggers_after_n_identical_calls() -> None:
|
|||
runtime,
|
||||
)
|
||||
name = type(excinfo.value).__name__.lower()
|
||||
assert (
|
||||
"interrupt" in name
|
||||
or "runtimeerror" in name
|
||||
), f"Expected an interrupt-style exception, got {name}"
|
||||
assert "interrupt" in name or "runtimeerror" in name, (
|
||||
f"Expected an interrupt-style exception, got {name}"
|
||||
)
|
||||
|
||||
|
||||
def test_does_not_trigger_when_args_differ() -> None:
|
||||
mw = DoomLoopMiddleware(threshold=2)
|
||||
runtime = _FakeRuntime()
|
||||
out = mw.after_model(
|
||||
{"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime
|
||||
)
|
||||
out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime)
|
||||
assert out is None
|
||||
out = mw.after_model(
|
||||
{"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime
|
||||
)
|
||||
out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime)
|
||||
assert out is None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -91,7 +91,9 @@ class TestShouldInject:
|
|||
mw = NoopInjectionMiddleware()
|
||||
req = _FakeRequest(
|
||||
tools=[object()],
|
||||
messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])],
|
||||
messages=[
|
||||
AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])
|
||||
],
|
||||
model=_LiteLLMModel(),
|
||||
)
|
||||
assert mw._should_inject(req) is False
|
||||
|
|
@ -109,7 +111,9 @@ class TestShouldInject:
|
|||
mw = NoopInjectionMiddleware()
|
||||
req = _FakeRequest(
|
||||
tools=[],
|
||||
messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])],
|
||||
messages=[
|
||||
AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])
|
||||
],
|
||||
model=_OpenAIModel(),
|
||||
)
|
||||
assert mw._should_inject(req) is False
|
||||
|
|
|
|||
|
|
@ -111,6 +111,4 @@ class TestAsk:
|
|||
assert out is None # call kept
|
||||
# Runtime ruleset got the always-allow rule
|
||||
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
|
||||
assert any(
|
||||
r.permission == "send_email" for r in new_rules
|
||||
)
|
||||
assert any(r.permission == "send_email" for r in new_rules)
|
||||
|
|
|
|||
|
|
@ -69,7 +69,9 @@ class TestPluginLoaderBasics:
|
|||
"app.agents.new_chat.plugin_loader.entry_points",
|
||||
return_value=[ep],
|
||||
):
|
||||
result = load_plugin_middlewares(_ctx(), allowed_plugin_names=["allowed_only"])
|
||||
result = load_plugin_middlewares(
|
||||
_ctx(), allowed_plugin_names=["allowed_only"]
|
||||
)
|
||||
assert result == []
|
||||
assert not called
|
||||
|
||||
|
|
@ -135,9 +137,7 @@ class TestPluginLoaderIsolation:
|
|||
_FakeEntryPoint("crashing", crashing_factory),
|
||||
_FakeEntryPoint("ok", year_substituter_factory),
|
||||
]
|
||||
with patch(
|
||||
"app.agents.new_chat.plugin_loader.entry_points", return_value=eps
|
||||
):
|
||||
with patch("app.agents.new_chat.plugin_loader.entry_points", return_value=eps):
|
||||
result = load_plugin_middlewares(
|
||||
_ctx(), allowed_plugin_names={"crashing", "ok"}
|
||||
)
|
||||
|
|
@ -151,9 +151,7 @@ class TestAllowlistEnv:
|
|||
assert load_allowed_plugin_names_from_env() == set()
|
||||
|
||||
def test_parses_comma_separated_value(self, monkeypatch) -> None:
|
||||
monkeypatch.setenv(
|
||||
"SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , "
|
||||
)
|
||||
monkeypatch.setenv("SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , ")
|
||||
assert load_allowed_plugin_names_from_env() == {
|
||||
"year_substituter",
|
||||
"noisy",
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class _FakeResponse:
|
|||
self.headers = headers
|
||||
|
||||
|
||||
class _FakeRateLimit(Exception):
|
||||
class _FakeRateLimitError(Exception):
|
||||
def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None:
|
||||
super().__init__(msg)
|
||||
if headers is not None:
|
||||
|
|
@ -27,15 +27,15 @@ class _FakeRateLimit(Exception):
|
|||
|
||||
class TestExtractRetryAfter:
|
||||
def test_seconds_header(self) -> None:
|
||||
exc = _FakeRateLimit("rate", {"Retry-After": "30"})
|
||||
exc = _FakeRateLimitError("rate", {"Retry-After": "30"})
|
||||
assert _extract_retry_after_seconds(exc) == 30.0
|
||||
|
||||
def test_milliseconds_header_overrides_seconds(self) -> None:
|
||||
exc = _FakeRateLimit("rate", {"retry-after-ms": "1500"})
|
||||
exc = _FakeRateLimitError("rate", {"retry-after-ms": "1500"})
|
||||
assert _extract_retry_after_seconds(exc) == 1.5
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
exc = _FakeRateLimit("rate", {"RETRY-AFTER": "12"})
|
||||
exc = _FakeRateLimitError("rate", {"RETRY-AFTER": "12"})
|
||||
assert _extract_retry_after_seconds(exc) == 12.0
|
||||
|
||||
def test_falls_back_to_message_regex(self) -> None:
|
||||
|
|
@ -67,7 +67,7 @@ class TestIsNonRetryable:
|
|||
class TestDelayCalculation:
|
||||
def test_takes_max_of_backoff_and_header(self) -> None:
|
||||
mw = RetryAfterMiddleware(max_retries=3, initial_delay=1.0, jitter=False)
|
||||
exc = _FakeRateLimit("rl", {"retry-after": "10"})
|
||||
exc = _FakeRateLimitError("rl", {"retry-after": "10"})
|
||||
delay = mw._delay_for_attempt(0, exc)
|
||||
assert delay == pytest.approx(10.0)
|
||||
|
||||
|
|
|
|||
|
|
@ -122,7 +122,9 @@ class TestExploreSubagent:
|
|||
def test_includes_permission_middleware_with_deny_rules(self) -> None:
|
||||
spec = build_explore_subagent(tools=ALL_TOOLS)
|
||||
permission_mws = [
|
||||
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
||||
m
|
||||
for m in spec["middleware"]
|
||||
if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
||||
]
|
||||
assert len(permission_mws) == 1
|
||||
ruleset = permission_mws[0]._static_rulesets[0]
|
||||
|
|
@ -164,7 +166,9 @@ class TestReportWriterSubagent:
|
|||
def test_deny_rules_block_writes_but_allow_generate_report(self) -> None:
|
||||
spec = build_report_writer_subagent(tools=ALL_TOOLS)
|
||||
permission_mws = [
|
||||
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
||||
m
|
||||
for m in spec["middleware"]
|
||||
if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
||||
]
|
||||
ruleset = permission_mws[0]._static_rulesets[0]
|
||||
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
|
||||
|
|
@ -194,17 +198,15 @@ class TestConnectorNegotiatorSubagent:
|
|||
def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None:
|
||||
spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
|
||||
permission_mws = [
|
||||
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
||||
m
|
||||
for m in spec["middleware"]
|
||||
if isinstance(m, PermissionMiddleware) # type: ignore[index]
|
||||
]
|
||||
ruleset = permission_mws[0]._static_rulesets[0]
|
||||
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
|
||||
# `linear_create_issue` matches the `*_create` deny pattern.
|
||||
assert any(
|
||||
_wildcard_matches(p, "linear_create_issue") for p in deny_patterns
|
||||
)
|
||||
assert any(
|
||||
_wildcard_matches(p, "slack_send_message") for p in deny_patterns
|
||||
)
|
||||
assert any(_wildcard_matches(p, "linear_create_issue") for p in deny_patterns)
|
||||
assert any(_wildcard_matches(p, "slack_send_message") for p in deny_patterns)
|
||||
|
||||
|
||||
class TestBuildSpecializedSubagents:
|
||||
|
|
@ -242,8 +244,7 @@ class TestBuildSpecializedSubagents:
|
|||
# order: extra → custom → patch → dedup.
|
||||
sentinel_idx = mws.index(sentinel)
|
||||
perm_idx = next(
|
||||
(i for i, m in enumerate(mws)
|
||||
if isinstance(m, PermissionMiddleware)),
|
||||
(i for i, m in enumerate(mws) if isinstance(m, PermissionMiddleware)),
|
||||
None,
|
||||
)
|
||||
assert perm_idx is not None
|
||||
|
|
@ -259,7 +260,9 @@ class TestFilterToolsWarningSuppression:
|
|||
|
||||
from app.agents.new_chat.subagents.config import _filter_tools
|
||||
|
||||
with caplog.at_level(logging.INFO, logger="app.agents.new_chat.subagents.config"):
|
||||
with caplog.at_level(
|
||||
logging.INFO, logger="app.agents.new_chat.subagents.config"
|
||||
):
|
||||
# Allowed set asks for two registry tools (one present, one
|
||||
# not) plus a bunch of middleware-provided names.
|
||||
_filter_tools(
|
||||
|
|
@ -275,9 +278,7 @@ class TestFilterToolsWarningSuppression:
|
|||
},
|
||||
)
|
||||
|
||||
warnings = [
|
||||
r.message for r in caplog.records if r.levelno >= logging.INFO
|
||||
]
|
||||
warnings = [r.message for r in caplog.records if r.levelno >= logging.INFO]
|
||||
# Exactly one warning, and it should mention scrape_webpage but not
|
||||
# any middleware-provided name. Inspect the rendered "missing"
|
||||
# list (between the brackets) so we don't false-match substrings
|
||||
|
|
|
|||
|
|
@ -27,9 +27,12 @@ class TestRepair:
|
|||
mw = ToolCallNameRepairMiddleware(
|
||||
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
||||
)
|
||||
msg = AIMessage(content="", tool_calls=[
|
||||
{"name": "echo", "args": {}, "id": "1"},
|
||||
])
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"name": "echo", "args": {}, "id": "1"},
|
||||
],
|
||||
)
|
||||
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
||||
assert out is None # no change
|
||||
|
||||
|
|
@ -37,9 +40,12 @@ class TestRepair:
|
|||
mw = ToolCallNameRepairMiddleware(
|
||||
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
||||
)
|
||||
msg = AIMessage(content="", tool_calls=[
|
||||
{"name": "Echo", "args": {"x": 1}, "id": "1"},
|
||||
])
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"name": "Echo", "args": {"x": 1}, "id": "1"},
|
||||
],
|
||||
)
|
||||
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
||||
assert out is not None
|
||||
repaired = out["messages"][0]
|
||||
|
|
@ -50,9 +56,12 @@ class TestRepair:
|
|||
registered_tool_names={"echo", INVALID_TOOL_NAME},
|
||||
fuzzy_match_threshold=None,
|
||||
)
|
||||
msg = AIMessage(content="", tool_calls=[
|
||||
{"name": "totally_different_name", "args": {"k": "v"}, "id": "1"},
|
||||
])
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"name": "totally_different_name", "args": {"k": "v"}, "id": "1"},
|
||||
],
|
||||
)
|
||||
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
||||
assert out is not None
|
||||
repaired_call = out["messages"][0].tool_calls[0]
|
||||
|
|
@ -64,9 +73,12 @@ class TestRepair:
|
|||
mw = ToolCallNameRepairMiddleware(
|
||||
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
||||
)
|
||||
msg = AIMessage(content="", tool_calls=[
|
||||
{"name": "unknown", "args": {}, "id": "1"},
|
||||
])
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"name": "unknown", "args": {}, "id": "1"},
|
||||
],
|
||||
)
|
||||
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
||||
# No repair available; original returned unchanged (no update)
|
||||
assert out is None
|
||||
|
|
@ -76,9 +88,12 @@ class TestRepair:
|
|||
registered_tool_names={"search_documents"},
|
||||
fuzzy_match_threshold=0.7,
|
||||
)
|
||||
msg = AIMessage(content="", tool_calls=[
|
||||
{"name": "search_docments", "args": {}, "id": "1"},
|
||||
])
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"name": "search_docments", "args": {}, "id": "1"},
|
||||
],
|
||||
)
|
||||
out = mw.after_model(_make_state(msg), _FakeRuntime())
|
||||
assert out is not None
|
||||
assert out["messages"][0].tool_calls[0]["name"] == "search_documents"
|
||||
|
|
@ -94,9 +109,12 @@ class TestRepair:
|
|||
mw = ToolCallNameRepairMiddleware(
|
||||
registered_tool_names={"echo"}, fuzzy_match_threshold=None
|
||||
)
|
||||
msg = AIMessage(content="", tool_calls=[
|
||||
{"name": "DynamicTool", "args": {}, "id": "1"},
|
||||
])
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"name": "DynamicTool", "args": {}, "id": "1"},
|
||||
],
|
||||
)
|
||||
runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"]))
|
||||
out = mw.after_model(_make_state(msg), runtime)
|
||||
assert out is not None
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ through :class:`KnowledgeBasePersistenceMiddleware` without losing the copy.
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
|
|
|||
|
|
@ -16,9 +16,7 @@ class _FakeAction:
|
|||
class TestCanRevert:
|
||||
def test_owner_can_revert_their_own_action(self) -> None:
|
||||
action = _FakeAction(user_id="user-123")
|
||||
assert can_revert(
|
||||
requester_user_id="user-123", action=action, is_admin=False
|
||||
)
|
||||
assert can_revert(requester_user_id="user-123", action=action, is_admin=False)
|
||||
|
||||
def test_other_user_cannot_revert(self) -> None:
|
||||
action = _FakeAction(user_id="user-123")
|
||||
|
|
@ -28,21 +26,15 @@ class TestCanRevert:
|
|||
|
||||
def test_admin_always_allowed(self) -> None:
|
||||
action = _FakeAction(user_id="user-123")
|
||||
assert can_revert(
|
||||
requester_user_id="anybody", action=action, is_admin=True
|
||||
)
|
||||
assert can_revert(requester_user_id="anybody", action=action, is_admin=True)
|
||||
|
||||
def test_admin_can_revert_anonymous_action(self) -> None:
|
||||
action = _FakeAction(user_id=None)
|
||||
assert can_revert(
|
||||
requester_user_id="admin", action=action, is_admin=True
|
||||
)
|
||||
assert can_revert(requester_user_id="admin", action=action, is_admin=True)
|
||||
|
||||
def test_anonymous_action_blocks_non_admin(self) -> None:
|
||||
action = _FakeAction(user_id=None)
|
||||
assert not can_revert(
|
||||
requester_user_id="user-1", action=action, is_admin=False
|
||||
)
|
||||
assert not can_revert(requester_user_id="user-1", action=action, is_admin=False)
|
||||
|
||||
def test_uuid_string_normalization(self) -> None:
|
||||
"""``user_id`` may be a UUID object; comparison should still work."""
|
||||
|
|
@ -51,6 +43,4 @@ class TestCanRevert:
|
|||
u = uuid.uuid4()
|
||||
action = _FakeAction(user_id=u)
|
||||
# Same UUID, passed as string from the requesting side.
|
||||
assert can_revert(
|
||||
requester_user_id=str(u), action=action, is_admin=False
|
||||
)
|
||||
assert can_revert(requester_user_id=str(u), action=action, is_admin=False)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue