feat: update agent harness

DESKTOP-RTLN3BA$punk 2026-04-28 09:22:19 -07:00
parent 9ec9b64348
commit 31a372bb84
139 changed files with 12583 additions and 1111 deletions


@@ -0,0 +1 @@


@@ -0,0 +1 @@


@@ -0,0 +1 @@
"""__init__ stub so pytest discovers the prompts test module."""


@@ -0,0 +1,201 @@
"""Tests for the prompt fragment composer (Tier 3a)."""
from __future__ import annotations
from datetime import UTC, datetime
import pytest
from app.agents.new_chat.prompts.composer import (
ALL_TOOL_NAMES_ORDERED,
compose_system_prompt,
detect_provider_variant,
)
from app.db import ChatVisibility
pytestmark = pytest.mark.unit
@pytest.fixture
def fixed_today() -> datetime:
return datetime(2025, 6, 1, 12, 0, tzinfo=UTC)
class TestProviderVariantDetection:
@pytest.mark.parametrize(
"model_name,expected",
[
("openai:gpt-4o-mini", "openai_classic"),
("openai:gpt-4-turbo", "openai_classic"),
("openai:gpt-5", "openai_reasoning"),
("openai:gpt-5-codex", "openai_reasoning"),
("openai:o1-preview", "openai_reasoning"),
("openai:o3-mini", "openai_reasoning"),
("anthropic:claude-3-5-sonnet", "anthropic"),
("anthropic/claude-opus-4", "anthropic"),
("google:gemini-2.0-flash", "google"),
("vertex:gemini-1.5-pro", "google"),
("groq:mixtral-8x7b", "default"),
(None, "default"),
("", "default"),
],
)
def test_detection(self, model_name: str | None, expected: str) -> None:
assert detect_provider_variant(model_name) == expected
class TestCompose:
def test_default_prompt_has_required_blocks(self, fixed_today: datetime) -> None:
prompt = compose_system_prompt(today=fixed_today)
# System instruction wrapper
assert "<system_instruction>" in prompt
assert "</system_instruction>" in prompt
# Date interpolated
assert "2025-06-01" in prompt
# Core policy blocks present
assert "<knowledge_base_only_policy>" in prompt
assert "<tool_routing>" in prompt
assert "<parameter_resolution>" in prompt
assert "<memory_protocol>" in prompt
# Tools
assert "<tools>" in prompt
assert "</tools>" in prompt
# Citations on by default
assert "<citation_instructions>" in prompt
assert "[citation:chunk_id]" in prompt
def test_team_visibility_uses_team_variants(
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt(
today=fixed_today,
thread_visibility=ChatVisibility.SEARCH_SPACE,
)
# Team-specific phrasing in the agent block
assert "team space" in prompt
# Memory protocol mentions team
assert "team" in prompt
# Should NOT mention the user-only memory phrasing
assert "personal knowledge base" not in prompt
def test_private_visibility_uses_private_variants(
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt(
today=fixed_today,
thread_visibility=ChatVisibility.PRIVATE,
)
assert "personal knowledge base" in prompt
# Should NOT mention the team-specific phrasing about prefixed authors
assert "[DisplayName of the author]" not in prompt
def test_citations_disabled_swaps_block(self, fixed_today: datetime) -> None:
prompt_on = compose_system_prompt(today=fixed_today, citations_enabled=True)
prompt_off = compose_system_prompt(today=fixed_today, citations_enabled=False)
assert "Citations are DISABLED" in prompt_off
assert "Citations are DISABLED" not in prompt_on
assert "[citation:chunk_id]" in prompt_on
def test_enabled_tool_filter_only_includes_listed_tools(
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt(
today=fixed_today,
enabled_tool_names={"web_search", "scrape_webpage"},
)
assert "web_search:" in prompt or "- web_search:" in prompt
assert "scrape_webpage:" in prompt or "- scrape_webpage:" in prompt
# Excluded tools should NOT appear in tool listing
assert "generate_podcast:" not in prompt
assert "generate_image:" not in prompt
def test_disabled_tool_note_is_appended(self, fixed_today: datetime) -> None:
prompt = compose_system_prompt(
today=fixed_today,
enabled_tool_names={"web_search"},
disabled_tool_names={"generate_image", "generate_podcast"},
)
assert "DISABLED TOOLS (by user):" in prompt
assert "Generate Image" in prompt
assert "Generate Podcast" in prompt
def test_mcp_routing_block_emits_when_provided(
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt(
today=fixed_today,
mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]},
)
assert "<mcp_tool_routing>" in prompt
assert "My GitLab" in prompt
assert "gitlab_search" in prompt
def test_mcp_routing_block_absent_when_no_servers(
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={})
assert "<mcp_tool_routing>" not in prompt
def test_provider_block_renders_when_anthropic(
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt(
today=fixed_today, model_name="anthropic:claude-3-5-sonnet"
)
assert "<provider_hints>" in prompt
assert "Anthropic" in prompt or "Claude" in prompt
def test_provider_block_absent_for_default(self, fixed_today: datetime) -> None:
prompt = compose_system_prompt(today=fixed_today, model_name="custom:foo")
assert "<provider_hints>" not in prompt
def test_custom_system_instructions_override_default(
self, fixed_today: datetime
) -> None:
custom = "You are a custom assistant. Today is {resolved_today}."
prompt = compose_system_prompt(
today=fixed_today, custom_system_instructions=custom
)
assert "You are a custom assistant. Today is 2025-06-01." in prompt
# Default block should NOT be present
assert "<knowledge_base_only_policy>" not in prompt
def test_use_default_false_with_no_custom_yields_no_system_block(
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt(
today=fixed_today,
use_default_system_instructions=False,
)
# No system_instruction wrapper but tools/citations still emitted
assert "<system_instruction>" not in prompt
assert "<tools>" in prompt
def test_all_known_tools_have_fragments(self) -> None:
# Soft assertion: verify that every tool in the canonical order
# produces non-empty content for at least one variant.
for tool in ALL_TOOL_NAMES_ORDERED:
prompt = compose_system_prompt(
today=datetime(2025, 1, 1, tzinfo=UTC),
enabled_tool_names={tool},
)
assert tool in prompt, f"tool {tool!r} missing from composed prompt"
class TestStableOrderingForCacheStability:
"""Regression guard: prompt cache hit-rate depends on byte-stable prefix."""
def test_composition_is_deterministic_given_same_inputs(
self, fixed_today: datetime
) -> None:
a = compose_system_prompt(
today=fixed_today,
enabled_tool_names={"web_search", "scrape_webpage"},
mcp_connector_tools={"X": ["x_a", "x_b"]},
)
b = compose_system_prompt(
today=fixed_today,
enabled_tool_names={"scrape_webpage", "web_search"}, # set order shouldn't matter
mcp_connector_tools={"X": ["x_a", "x_b"]},
)
assert a == b
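# --- Illustrative sketch (editor's addition, not part of the diff) --------
# A minimal implementation consistent with the parametrized cases above.
# The real detect_provider_variant lives in app.agents.new_chat.prompts
# .composer and is not shown in this hunk; the branch order here is an
# assumption.
def detect_provider_variant_sketch(model_name: str | None) -> str:
    if not model_name:
        return "default"
    name = model_name.lower()
    if name.startswith("openai:"):
        model = name.split(":", 1)[1]
        # gpt-5* and the o-series take the reasoning-tuned prompt variant
        if model.startswith(("gpt-5", "o1", "o3")):
            return "openai_reasoning"
        return "openai_classic"
    if name.startswith(("anthropic:", "anthropic/")):
        return "anthropic"
    if name.startswith(("google:", "vertex:")):
        return "google"
    return "default"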


@@ -0,0 +1,311 @@
"""Unit tests for ActionLogMiddleware (Tier 5.2)."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
from app.agents.new_chat.tools.registry import ToolDefinition
@dataclass
class _FakeRequest:
"""Minimal stand-in for ToolCallRequest used in unit tests."""
tool_call: dict[str, Any]
tool: Any = None
state: Any = None
runtime: Any = None
@tool
def make_widget(color: str, size: int) -> str:
"""Create a widget."""
return f"made {color} {size}"
def _enabled_flags(**overrides: bool) -> AgentFeatureFlags:
return AgentFeatureFlags(
disable_new_agent_stack=False,
enable_action_log=True,
**overrides,
)
def _disabled_flags() -> AgentFeatureFlags:
return AgentFeatureFlags(disable_new_agent_stack=False, enable_action_log=False)
@pytest.fixture
def patch_get_flags():
def _patch(flags: AgentFeatureFlags):
return patch(
"app.agents.new_chat.middleware.action_log.get_flags",
return_value=flags,
)
return _patch
@pytest.fixture
def fake_session_factory():
"""Patch ``shielded_async_session`` with a recording fake."""
captured: dict[str, list] = {"rows": []}
class _FakeSession:
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
def add(self, row):
captured["rows"].append(row)
async def commit(self):
captured["committed"] = True
def _factory():
return _FakeSession()
return captured, _factory
class TestActionLogMiddlewareDisabled:
@pytest.mark.asyncio
async def test_no_op_when_flag_off(self, patch_get_flags) -> None:
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
request = _FakeRequest(
tool_call={"name": "make_widget", "args": {"color": "red", "size": 1}, "id": "tc1"}
)
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
with patch_get_flags(_disabled_flags()):
result = await mw.awrap_tool_call(request, handler)
handler.assert_awaited_once()
assert isinstance(result, ToolMessage)
@pytest.mark.asyncio
async def test_no_op_when_thread_id_none(self, patch_get_flags) -> None:
mw = ActionLogMiddleware(thread_id=None, search_space_id=1, user_id=None)
request = _FakeRequest(
tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
)
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
with patch_get_flags(_enabled_flags()):
result = await mw.awrap_tool_call(request, handler)
assert isinstance(result, ToolMessage)
class TestActionLogMiddlewarePersistence:
@pytest.mark.asyncio
async def test_writes_row_on_success(
self, patch_get_flags, fake_session_factory
) -> None:
captured, factory = fake_session_factory
mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
request = _FakeRequest(
tool_call={
"name": "make_widget",
"args": {"color": "red", "size": 3},
"id": "tc-abc",
},
)
result_msg = ToolMessage(
content="ok", tool_call_id="tc-abc", id="msg-1"
)
handler = AsyncMock(return_value=result_msg)
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
):
result = await mw.awrap_tool_call(request, handler)
assert result is result_msg
assert len(captured["rows"]) == 1
row = captured["rows"][0]
assert row.thread_id == 42
assert row.search_space_id == 7
assert row.user_id == "u1"
assert row.tool_name == "make_widget"
assert row.args == {"color": "red", "size": 3}
assert row.result_id == "msg-1"
assert row.error is None
assert row.reverse_descriptor is None
assert row.reversible is False
@pytest.mark.asyncio
async def test_writes_row_on_failure_and_reraises(
self, patch_get_flags, fake_session_factory
) -> None:
captured, factory = fake_session_factory
mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
request = _FakeRequest(
tool_call={"name": "make_widget", "args": {"color": "red"}, "id": "tc1"}
)
handler = AsyncMock(side_effect=ValueError("boom"))
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
), pytest.raises(ValueError, match="boom"):
await mw.awrap_tool_call(request, handler)
assert len(captured["rows"]) == 1
row = captured["rows"][0]
assert row.tool_name == "make_widget"
assert row.error == {"type": "ValueError", "message": "boom"}
assert row.result_id is None
@pytest.mark.asyncio
async def test_persistence_failure_does_not_break_tool_call(
self, patch_get_flags
) -> None:
"""Even if the DB write blows up, the tool's result must reach the model."""
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
request = _FakeRequest(
tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
)
result_msg = ToolMessage(content="ok", tool_call_id="tc1")
handler = AsyncMock(return_value=result_msg)
def _exploding_session():
raise RuntimeError("DB is down")
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=_exploding_session
):
result = await mw.awrap_tool_call(request, handler)
assert result is result_msg
class TestReverseDescriptor:
@pytest.mark.asyncio
async def test_renders_reverse_descriptor_when_tool_declares_one(
self, patch_get_flags, fake_session_factory
) -> None:
captured, factory = fake_session_factory
def _reverse(args, result):
return {"tool": "delete_widget", "args": {"id": result["id"]}}
tool_def = ToolDefinition(
name="make_widget",
description="Create a widget",
factory=lambda deps: make_widget,
reverse=_reverse,
)
mw = ActionLogMiddleware(
thread_id=1,
search_space_id=1,
user_id="u",
tool_definitions={"make_widget": tool_def},
)
request = _FakeRequest(
tool_call={
"name": "make_widget",
"args": {"color": "blue", "size": 1},
"id": "tc-xyz",
},
)
result_msg = ToolMessage(
content='{"id": "widget-9"}', tool_call_id="tc-xyz", id="msg-9"
)
handler = AsyncMock(return_value=result_msg)
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
):
await mw.awrap_tool_call(request, handler)
row = captured["rows"][0]
assert row.reversible is True
assert row.reverse_descriptor == {
"tool": "delete_widget",
"args": {"id": "widget-9"},
}
@pytest.mark.asyncio
async def test_swallows_reverse_callable_errors(
self, patch_get_flags, fake_session_factory
) -> None:
captured, factory = fake_session_factory
def _bad_reverse(args, result):
raise RuntimeError("reverse blew up")
tool_def = ToolDefinition(
name="make_widget",
description="Create a widget",
factory=lambda deps: make_widget,
reverse=_bad_reverse,
)
mw = ActionLogMiddleware(
thread_id=1,
search_space_id=1,
user_id=None,
tool_definitions={"make_widget": tool_def},
)
request = _FakeRequest(
tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
)
result_msg = ToolMessage(content="ok", tool_call_id="tc1")
handler = AsyncMock(return_value=result_msg)
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
):
await mw.awrap_tool_call(request, handler)
row = captured["rows"][0]
assert row.reversible is False
assert row.reverse_descriptor is None
@pytest.mark.asyncio
async def test_no_reverse_when_tool_definition_missing(
self, patch_get_flags, fake_session_factory
) -> None:
captured, factory = fake_session_factory
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
request = _FakeRequest(
tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"}
)
handler = AsyncMock(
return_value=ToolMessage(content="ok", tool_call_id="tc1")
)
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
):
await mw.awrap_tool_call(request, handler)
row = captured["rows"][0]
assert row.reversible is False
class TestArgsTruncation:
@pytest.mark.asyncio
async def test_huge_args_payload_is_truncated(
self, patch_get_flags, fake_session_factory
) -> None:
captured, factory = fake_session_factory
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
# Build a > 32KB string so the persisted payload triggers the truncation path.
huge = "x" * (40 * 1024)
request = _FakeRequest(
tool_call={"name": "make_widget", "args": {"blob": huge}, "id": "tc1"},
)
handler = AsyncMock(
return_value=ToolMessage(content="ok", tool_call_id="tc1")
)
with patch_get_flags(_enabled_flags()), patch(
"app.db.shielded_async_session", side_effect=lambda: factory()
):
await mw.awrap_tool_call(request, handler)
row = captured["rows"][0]
assert row.args is not None
assert row.args.get("_truncated") is True
assert row.args.get("_size", 0) >= 40 * 1024
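# --- Illustrative sketch (editor's addition, not part of the diff) --------
# One way the middleware could satisfy the truncation contract asserted in
# TestArgsTruncation. The 32KB threshold and the "_truncated"/"_size" keys
# come from the test; the JSON serialization choice is an assumption.
import json

_MAX_ARGS_BYTES = 32 * 1024

def _truncate_args_sketch(args: dict) -> dict:
    payload = json.dumps(args, default=str)
    if len(payload) <= _MAX_ARGS_BYTES:
        return args
    return {"_truncated": True, "_size": len(payload)}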


@@ -0,0 +1,90 @@
"""Tests for BusyMutexMiddleware: per-thread lock + cancel event behavior."""
from __future__ import annotations
import pytest
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import (
BusyMutexMiddleware,
get_cancel_event,
manager,
request_cancel,
reset_cancel,
)
pytestmark = pytest.mark.unit
class _Runtime:
def __init__(self, thread_id: str | None) -> None:
self.config = {"configurable": {"thread_id": thread_id}}
@pytest.mark.asyncio
async def test_first_acquire_succeeds_and_release_unblocks() -> None:
mw = BusyMutexMiddleware()
runtime = _Runtime("t1")
await mw.abefore_agent({}, runtime)
# Lock should now be held
lock = manager.lock_for("t1")
assert lock.locked()
await mw.aafter_agent({}, runtime)
assert not lock.locked()
@pytest.mark.asyncio
async def test_second_concurrent_acquire_raises_busy() -> None:
mw_a = BusyMutexMiddleware()
mw_b = BusyMutexMiddleware()
runtime = _Runtime("t-conflict")
await mw_a.abefore_agent({}, runtime)
with pytest.raises(BusyError) as excinfo:
await mw_b.abefore_agent({}, runtime)
assert excinfo.value.request_id == "t-conflict"
await mw_a.aafter_agent({}, runtime)
# After release, mw_b can acquire
await mw_b.abefore_agent({}, runtime)
await mw_b.aafter_agent({}, runtime)
@pytest.mark.asyncio
async def test_cancel_event_lifecycle() -> None:
mw = BusyMutexMiddleware()
runtime = _Runtime("t-cancel")
await mw.abefore_agent({}, runtime)
event = get_cancel_event("t-cancel")
assert not event.is_set()
request_cancel("t-cancel")
assert event.is_set()
# End of turn should reset
await mw.aafter_agent({}, runtime)
assert not event.is_set()
@pytest.mark.asyncio
async def test_no_thread_id_raises_when_required() -> None:
mw = BusyMutexMiddleware(require_thread_id=True)
runtime = _Runtime(None)
with pytest.raises(BusyError):
await mw.abefore_agent({}, runtime)
@pytest.mark.asyncio
async def test_no_thread_id_skipped_when_not_required() -> None:
mw = BusyMutexMiddleware(require_thread_id=False)
runtime = _Runtime(None)
await mw.abefore_agent({}, runtime)
await mw.aafter_agent({}, runtime)
def test_reset_cancel_idempotent() -> None:
# Should not raise even if event was never created
reset_cancel("never-seen")
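# --- Illustrative sketch (editor's addition, not part of the diff) --------
# Hypothetical call site for the cancel API exercised above: an HTTP cancel
# route flips the per-thread event, which the agent loop polls between
# steps. The route shape and naming are assumptions.
from app.agents.new_chat.middleware.busy_mutex import get_cancel_event, request_cancel

async def cancel_turn(thread_id: str) -> dict:
    request_cancel(thread_id)  # sets the asyncio.Event for this thread
    return {"cancelled": get_cancel_event(thread_id).is_set()}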


@@ -0,0 +1,107 @@
"""Tests for SurfSenseCompactionMiddleware: protected SystemMessage handling and content sanitization."""
from __future__ import annotations
import pytest
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from app.agents.new_chat.middleware.compaction import (
PROTECTED_SYSTEM_PREFIXES,
_is_protected_system_message,
_sanitize_message_content,
)
pytestmark = pytest.mark.unit
class TestIsProtectedSystemMessage:
@pytest.mark.parametrize("prefix", PROTECTED_SYSTEM_PREFIXES)
def test_each_prefix_protected(self, prefix: str) -> None:
msg = SystemMessage(content=f"{prefix}\nbody\n</close>")
assert _is_protected_system_message(msg) is True
def test_unprotected_system_message(self) -> None:
assert _is_protected_system_message(SystemMessage(content="random instructions")) is False
def test_human_message_never_protected(self) -> None:
assert _is_protected_system_message(HumanMessage(content="<workspace_tree>...")) is False
def test_tolerates_leading_whitespace(self) -> None:
msg = SystemMessage(content=" \n<priority_documents>\n...")
assert _is_protected_system_message(msg) is True
class TestSanitizeMessageContent:
def test_returns_same_message_when_content_present(self) -> None:
msg = AIMessage(content="hello")
assert _sanitize_message_content(msg) is msg
def test_replaces_none_with_empty_string(self) -> None:
# Pydantic blocks ``content=None`` at construction; the real
# crash happens when the streaming layer mutates ``content``
# after-the-fact. Replicate that by force-setting on a built
# message.
msg = AIMessage(
content="",
tool_calls=[{"name": "x", "args": {}, "id": "1"}],
)
# Bypass pydantic validation to simulate the LiteLLM/Bedrock case
object.__setattr__(msg, "content", None)
sanitized = _sanitize_message_content(msg)
assert sanitized.content == ""
class TestPartitionMessages:
"""Verify the partition override surfaces protected SystemMessages
into ``preserved_messages`` regardless of cutoff position.
"""
def _build_partitioner(self):
# Construct a thin shim — we can't easily instantiate the full
# SurfSenseCompactionMiddleware without a real model, but the
# override path needs ``_lc_helper`` to delegate to. We mock
# that with a simple slicing partitioner equivalent to the real one.
from app.agents.new_chat.middleware.compaction import (
SurfSenseCompactionMiddleware,
)
class _LcHelper:
@staticmethod
def _partition_messages(messages, cutoff):
return messages[:cutoff], messages[cutoff:]
class _Stub(SurfSenseCompactionMiddleware):
def __init__(self):
self._lc_helper = _LcHelper()
return _Stub()
def test_protected_system_message_preserved_even_in_summarize_half(self) -> None:
partitioner = self._build_partitioner()
protected = SystemMessage(content="<priority_documents>\n...")
msgs = [
HumanMessage(content="old human"),
AIMessage(content="old ai"),
protected,
ToolMessage(content="tool 1", tool_call_id="t1"),
HumanMessage(content="new"),
]
# Cutoff = 4 means everything before index 4 should be summarized
to_summary, preserved = partitioner._partition_messages(msgs, 4)
assert protected not in to_summary
assert protected in preserved
# The non-protected old messages remain in to_summary
assert any(isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary)
def test_unprotected_messages_unaffected(self) -> None:
partitioner = self._build_partitioner()
msgs = [HumanMessage(content="a"), HumanMessage(content="b"), HumanMessage(content="c")]
to_summary, preserved = partitioner._partition_messages(msgs, 2)
assert [m.content for m in to_summary] == ["a", "b"]
assert [m.content for m in preserved] == ["c"]
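# --- Illustrative sketch (editor's addition, not part of the diff) --------
# A check consistent with the assertions above: only SystemMessages count,
# and protection is a prefix match tolerant of leading whitespace. The real
# _is_protected_system_message may differ in detail.
from langchain_core.messages import BaseMessage

def _is_protected_sketch(msg: BaseMessage, prefixes: tuple[str, ...]) -> bool:
    if not isinstance(msg, SystemMessage) or not isinstance(msg.content, str):
        return False
    return msg.content.lstrip().startswith(prefixes)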


@@ -0,0 +1,107 @@
"""Tests for SpillToBackendEdit and SpillingContextEditingMiddleware."""
from __future__ import annotations
from typing import Any
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.agents.new_chat.middleware.context_editing import (
SpillToBackendEdit,
_build_spill_placeholder,
)
pytestmark = pytest.mark.unit
def _build_history(num_pairs: int = 6) -> list[Any]:
"""Build a long history of (AIMessage with tool_call, ToolMessage) pairs."""
msgs: list[Any] = [HumanMessage(content="please do many things")]
for i in range(num_pairs):
msgs.append(
AIMessage(
content="",
tool_calls=[
{"name": f"tool_{i}", "args": {"i": i}, "id": f"call-{i}"},
],
)
)
msgs.append(
ToolMessage(
content="x" * 5000,
tool_call_id=f"call-{i}",
name=f"tool_{i}",
id=f"tool-msg-{i}",
)
)
return msgs
def _approx_count(messages: list[Any]) -> int:
"""Trivial token counter: 1 token per 4 chars."""
total = 0
for msg in messages:
content = getattr(msg, "content", "")
if isinstance(content, str):
total += len(content) // 4
return total
class TestSpillEdit:
def test_below_trigger_does_nothing(self) -> None:
edit = SpillToBackendEdit(trigger=1_000_000, keep=2)
msgs = _build_history(3)
original_lengths = [len(getattr(m, "content", "")) for m in msgs]
edit.apply(msgs, count_tokens=_approx_count)
new_lengths = [len(getattr(m, "content", "")) for m in msgs]
assert original_lengths == new_lengths
assert edit.pending_spills == []
def test_above_trigger_clears_and_records(self) -> None:
edit = SpillToBackendEdit(trigger=100, keep=1, path_prefix="/tool_outputs")
msgs = _build_history(4)
edit.apply(msgs, count_tokens=_approx_count)
# The most-recent ToolMessage (keep=1) should remain intact
tool_messages = [m for m in msgs if isinstance(m, ToolMessage)]
intact = tool_messages[-1]
assert intact.content.startswith("x") # untouched
# Earlier ToolMessages should now contain the placeholder text
cleared = [
m for m in tool_messages
if isinstance(m.content, str) and m.content.startswith("[cleared")
]
assert len(cleared) >= 1
# And the spill list should match
assert len(edit.pending_spills) == len(cleared)
def test_excluded_tools_not_cleared(self) -> None:
edit = SpillToBackendEdit(
trigger=100,
keep=0,
exclude_tools=("tool_0",),
)
msgs = _build_history(4)
edit.apply(msgs, count_tokens=_approx_count)
first_tool = next(
m for m in msgs if isinstance(m, ToolMessage) and m.name == "tool_0"
)
# Excluded — untouched
assert first_tool.content.startswith("x")
def test_drain_clears_pending(self) -> None:
edit = SpillToBackendEdit(trigger=100, keep=1)
msgs = _build_history(4)
edit.apply(msgs, count_tokens=_approx_count)
first_drain = edit.drain_pending()
assert len(first_drain) > 0
assert edit.drain_pending() == []
def test_placeholder_format(self) -> None:
path = "/tool_outputs/thread-1/tool-msg-0.txt"
text = _build_spill_placeholder(path)
assert path in text
assert "explore" in text # mentions the recovery agent
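# --- Illustrative sketch (editor's addition, not part of the diff) --------
# What _build_spill_placeholder plausibly returns, reverse-engineered from
# the assertions above. The tests only pin the "[cleared" prefix, the path,
# and a mention of the explore agent; the exact wording is an assumption.
def _build_spill_placeholder_sketch(path: str) -> str:
    return (
        f"[cleared] Large tool output was spilled to {path}. "
        "Use the explore subagent to read it back if needed."
    )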


@@ -0,0 +1,132 @@
"""Tests for declarative dedup_key on ToolDefinition (Tier 2.3 migration)."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage
from langchain_core.tools import StructuredTool
from app.agents.new_chat.middleware.dedup_tool_calls import (
DedupHITLToolCallsMiddleware,
)
pytestmark = pytest.mark.unit
def _make_tool(name: str, *, dedup_key=None, hitl_dedup_key=None):
metadata = {}
if dedup_key is not None:
metadata["dedup_key"] = dedup_key
if hitl_dedup_key is not None:
metadata["hitl"] = True
metadata["hitl_dedup_key"] = hitl_dedup_key
def _fn(**kwargs):
return "ok"
return StructuredTool.from_function(
func=_fn, name=name, description="x", metadata=metadata
)
def _msg(*calls: dict) -> AIMessage:
return AIMessage(content="", tool_calls=list(calls))
class _Runtime:
pass
def test_callable_dedup_key_takes_priority() -> None:
tool = _make_tool(
"create_doc",
dedup_key=lambda args: f"{args.get('parent_id')}::{args.get('title')}",
)
mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
state = {
"messages": [
_msg(
{"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "1"},
{"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "2"},
{"name": "create_doc", "args": {"parent_id": "x", "title": "z"}, "id": "3"},
)
]
}
out = mw.after_model(state, _Runtime())
assert out is not None
new_calls = out["messages"][0].tool_calls
assert len(new_calls) == 2 # one duplicate dropped
assert {c["id"] for c in new_calls} == {"1", "3"}
def test_string_hitl_dedup_key_still_works() -> None:
tool = _make_tool("send_x", hitl_dedup_key="subject")
mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
state = {
"messages": [
_msg(
{"name": "send_x", "args": {"subject": "Hello"}, "id": "1"},
{"name": "send_x", "args": {"subject": "hello"}, "id": "2"}, # case
)
]
}
out = mw.after_model(state, _Runtime())
assert out is not None
assert len(out["messages"][0].tool_calls) == 1
def test_no_agent_tools_means_no_dedup() -> None:
"""After the cleanup tier removed the legacy ``_NATIVE_HITL_TOOL_DEDUP_KEYS``
    map, dedup is purely declarative: no resolvers means no dedup runs.
Coverage for the previously hardcoded native HITL tools now lives on
each :class:`ToolDefinition.dedup_key` in
:mod:`app.agents.new_chat.tools.registry`, which is wired through to
``tool.metadata`` by :func:`build_tools`.
"""
mw = DedupHITLToolCallsMiddleware(agent_tools=None)
state = {
"messages": [
_msg(
{"name": "create_notion_page", "args": {"title": "X"}, "id": "1"},
{"name": "create_notion_page", "args": {"title": "x"}, "id": "2"},
)
]
}
out = mw.after_model(state, _Runtime())
assert out is None
def test_registry_propagates_dedup_key_to_tool_metadata() -> None:
"""Smoke-check the wiring path that replaced the legacy native map.
``ToolDefinition.dedup_key`` set in the registry must be copied onto
the constructed tool's ``metadata`` so :class:`DedupHITLToolCallsMiddleware`
can pick it up at agent build time.
"""
from app.agents.new_chat.tools.registry import (
BUILTIN_TOOLS,
wrap_dedup_key_by_arg_name,
)
notion_tool_defs = [t for t in BUILTIN_TOOLS if t.name == "create_notion_page"]
assert notion_tool_defs, "registry should still expose create_notion_page"
tool_def = notion_tool_defs[0]
assert tool_def.dedup_key is not None
# Same wrapping helper used in the registry — sanity check identity
sample = wrap_dedup_key_by_arg_name("title")({"title": "Plan"})
assert sample == "plan"
def test_unknown_tool_passes_through() -> None:
mw = DedupHITLToolCallsMiddleware(agent_tools=None)
state = {
"messages": [
_msg(
{"name": "anything_else", "args": {"x": 1}, "id": "1"},
{"name": "anything_else", "args": {"x": 1}, "id": "2"},
)
]
}
out = mw.after_model(state, _Runtime())
assert out is None # no dedup configured -> kept
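# --- Illustrative sketch (editor's addition, not part of the diff) --------
# The identity check above (wrap_dedup_key_by_arg_name("title")({"title":
# "Plan"}) == "plan") implies the wrapper lower-cases the named argument;
# the whitespace stripping is an assumption.
from collections.abc import Callable

def wrap_dedup_key_by_arg_name_sketch(arg_name: str) -> Callable[[dict], str]:
    def _key(args: dict) -> str:
        return str(args.get(arg_name, "")).strip().lower()
    return _key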


@@ -0,0 +1,128 @@
"""Lock in the default-allow layering used by ``chat_deepagent``.
The agent factory wires ``PermissionMiddleware`` with three rulesets,
earliest -> latest:
1. ``surfsense_defaults`` (single ``allow */*`` rule)
2. ``connector_synthesized`` (deny rules for tools whose required
connector is missing)
3. (future) user-defined rules from the Agent Permissions UI
Without #1 every read-only built-in (``ls``, ``read_file``, ``grep``,
``glob``, ``web_search``) defaulted to ``ask`` because
``permissions.evaluate`` returns ``ask`` when no rule matches. That
caused two painful production behaviors:
* Resume payloads with a prior reject decision bled into innocent
read-only tool calls, raising ``RejectedError("ls")``.
* Mutating connector tools got prompted *twice*: once via the
  middleware ``ask`` and again via the per-tool ``interrupt()`` in
  ``app.agents.new_chat.tools.hitl``.
These tests pin the layering so a refactor that drops the default
ruleset fails loudly.
"""
from __future__ import annotations
import pytest
from app.agents.new_chat.permissions import (
Rule,
Ruleset,
aggregate_action,
evaluate_many,
)
pytestmark = pytest.mark.unit
def _layered_rulesets(connector_denies: list[Rule]) -> list[Ruleset]:
"""Replicate ``chat_deepagent`` layering for the test."""
return [
Ruleset(
rules=[Rule(permission="*", pattern="*", action="allow")],
origin="surfsense_defaults",
),
Ruleset(rules=connector_denies, origin="connector_synthesized"),
]
class TestReadOnlyToolsAllowed:
"""Read-only built-ins must NOT default to ask."""
@pytest.mark.parametrize(
"tool_name",
[
"ls",
"read_file",
"grep",
"glob",
"web_search",
"scrape_webpage",
"search_surfsense_docs",
"get_connected_accounts",
"write_todos",
"task",
"_noop",
"invalid",
"update_memory",
],
)
def test_default_allow_covers_safe_builtin(self, tool_name: str) -> None:
rulesets = _layered_rulesets(connector_denies=[])
rules = evaluate_many(tool_name, [tool_name], *rulesets)
assert aggregate_action(rules) == "allow"
class TestConnectorDenyOverridesDefaultAllow:
"""Connector-synthesized denies must beat the default-allow rule."""
def test_missing_connector_tool_is_denied(self) -> None:
rulesets = _layered_rulesets(
connector_denies=[
Rule(permission="linear_create_issue", pattern="*", action="deny")
]
)
rules = evaluate_many(
"linear_create_issue", ["linear_create_issue"], *rulesets
)
assert aggregate_action(rules) == "deny"
def test_default_allow_still_applies_to_other_tools(self) -> None:
"""A deny rule for one tool must not bleed onto unrelated calls."""
rulesets = _layered_rulesets(
connector_denies=[
Rule(permission="linear_create_issue", pattern="*", action="deny")
]
)
rules = evaluate_many("ls", ["ls"], *rulesets)
assert aggregate_action(rules) == "allow"
class TestUserRuleOverridesDefault:
"""User rules layered last must override the default-allow rule."""
def test_user_ask_overrides_default_allow(self) -> None:
defaults = Ruleset(
rules=[Rule(permission="*", pattern="*", action="allow")],
origin="surfsense_defaults",
)
user_ruleset = Ruleset(
rules=[Rule(permission="ls", pattern="*", action="ask")],
origin="user",
)
rules = evaluate_many("ls", ["ls"], defaults, user_ruleset)
assert aggregate_action(rules) == "ask"
def test_user_deny_overrides_default_allow(self) -> None:
defaults = Ruleset(
rules=[Rule(permission="*", pattern="*", action="allow")],
origin="surfsense_defaults",
)
user_ruleset = Ruleset(
rules=[Rule(permission="send_*", pattern="*", action="deny")],
origin="user",
)
rules = evaluate_many("send_gmail_email", ["send_gmail_email"], defaults, user_ruleset)
assert aggregate_action(rules) == "deny"
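# --- Illustrative sketch (editor's addition, not part of the diff) --------
# How chat_deepagent presumably assembles the three layers these tests pin.
# Only the ordering is the contract; the variable names are assumptions.
connector_denies: list[Rule] = []   # synthesized from missing connectors
user_rules: list[Rule] = []         # from the Agent Permissions UI (future)
layered = [
    Ruleset(rules=[Rule(permission="*", pattern="*", action="allow")],
            origin="surfsense_defaults"),
    Ruleset(rules=connector_denies, origin="connector_synthesized"),
    Ruleset(rules=user_rules, origin="user"),  # layered last, so it wins
]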


@@ -0,0 +1,99 @@
"""Tests for DoomLoopMiddleware signature equality detection."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage
from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware, _signature
pytestmark = pytest.mark.unit
def test_signature_is_stable_for_identical_args() -> None:
a = _signature("search", {"q": "hello", "n": 10})
b = _signature("search", {"n": 10, "q": "hello"})
assert a == b
def test_signature_changes_with_args() -> None:
a = _signature("search", {"q": "hello"})
b = _signature("search", {"q": "world"})
assert a != b
def test_signature_changes_with_name() -> None:
a = _signature("search", {"q": "x"})
b = _signature("read", {"q": "x"})
assert a != b
class _FakeRuntime:
def __init__(self, thread_id: str | None = "thread-1") -> None:
self.config = {"configurable": {"thread_id": thread_id}}
def _msg_calling(name: str, args: dict, call_id: str) -> AIMessage:
return AIMessage(
content="",
tool_calls=[{"name": name, "args": args, "id": call_id}],
)
def test_threshold_triggers_after_n_identical_calls() -> None:
mw = DoomLoopMiddleware(threshold=3)
runtime = _FakeRuntime()
# First two calls — under threshold
for i in range(2):
out = mw.after_model(
{"messages": [_msg_calling("repeat", {"x": 1}, f"call-{i}")]},
runtime,
)
assert out is None
# Third identical call should trigger ``langgraph.types.interrupt``.
# In a unit-test context (no runnable graph), ``interrupt`` raises
# ``RuntimeError`` because ``get_config`` has nothing to bind to —
# we accept that as proof the interrupt path was taken (the
# alternative would be no exception, which would mean the loop
# detection never fired).
with pytest.raises(Exception) as excinfo:
mw.after_model(
{"messages": [_msg_calling("repeat", {"x": 1}, "call-3")]},
runtime,
)
name = type(excinfo.value).__name__.lower()
assert (
"interrupt" in name
or "runtimeerror" in name
), f"Expected an interrupt-style exception, got {name}"
def test_does_not_trigger_when_args_differ() -> None:
mw = DoomLoopMiddleware(threshold=2)
runtime = _FakeRuntime()
out = mw.after_model(
{"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime
)
assert out is None
out = mw.after_model(
{"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime
)
assert out is None
def test_separate_threads_have_independent_windows() -> None:
mw = DoomLoopMiddleware(threshold=2)
rt_a = _FakeRuntime(thread_id="A")
rt_b = _FakeRuntime(thread_id="B")
mw.after_model({"messages": [_msg_calling("foo", {}, "1")]}, rt_a)
# thread B should NOT count thread A's call
out = mw.after_model({"messages": [_msg_calling("foo", {}, "1")]}, rt_b)
assert out is None # not yet at threshold for B
def test_invalid_threshold_rejected() -> None:
with pytest.raises(ValueError):
DoomLoopMiddleware(threshold=1)
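# --- Illustrative sketch (editor's addition, not part of the diff) --------
# _signature must be order-insensitive over args (first test above); a hash
# of a canonical JSON encoding satisfies that. Hashing is an assumption;
# any stable canonical encoding would pass these tests.
import hashlib
import json

def _signature_sketch(name: str, args: dict) -> str:
    canonical = json.dumps(args, sort_keys=True, default=str)
    return hashlib.sha256(f"{name}:{canonical}".encode()).hexdigest()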


@@ -0,0 +1,120 @@
"""Tests for the agent feature-flag system."""
from __future__ import annotations
import pytest
from app.agents.new_chat.feature_flags import (
AgentFeatureFlags,
reload_for_tests,
)
pytestmark = pytest.mark.unit
def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
for name in [
"SURFSENSE_DISABLE_NEW_AGENT_STACK",
"SURFSENSE_ENABLE_CONTEXT_EDITING",
"SURFSENSE_ENABLE_COMPACTION_V2",
"SURFSENSE_ENABLE_RETRY_AFTER",
"SURFSENSE_ENABLE_MODEL_FALLBACK",
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT",
"SURFSENSE_ENABLE_TOOL_CALL_LIMIT",
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR",
"SURFSENSE_ENABLE_DOOM_LOOP",
"SURFSENSE_ENABLE_PERMISSION",
"SURFSENSE_ENABLE_BUSY_MUTEX",
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
"SURFSENSE_ENABLE_SKILLS",
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"SURFSENSE_ENABLE_ACTION_LOG",
"SURFSENSE_ENABLE_REVERT_ROUTE",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
"SURFSENSE_ENABLE_OTEL",
]:
monkeypatch.delenv(name, raising=False)
def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None:
_clear_all(monkeypatch)
flags = reload_for_tests()
assert isinstance(flags, AgentFeatureFlags)
assert flags.disable_new_agent_stack is False
assert flags.any_new_middleware_enabled() is False
def test_master_kill_switch_overrides_individual_flags(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_clear_all(monkeypatch)
monkeypatch.setenv("SURFSENSE_DISABLE_NEW_AGENT_STACK", "true")
monkeypatch.setenv("SURFSENSE_ENABLE_CONTEXT_EDITING", "true")
monkeypatch.setenv("SURFSENSE_ENABLE_PERMISSION", "true")
flags = reload_for_tests()
assert flags.disable_new_agent_stack is True
assert flags.enable_context_editing is False
assert flags.enable_permission is False
assert flags.any_new_middleware_enabled() is False
@pytest.mark.parametrize("truthy", ["1", "true", "TRUE", "yes", "on"])
def test_individual_flags_truthy_values(
monkeypatch: pytest.MonkeyPatch, truthy: str
) -> None:
_clear_all(monkeypatch)
monkeypatch.setenv("SURFSENSE_ENABLE_RETRY_AFTER", truthy)
flags = reload_for_tests()
assert flags.enable_retry_after is True
assert flags.any_new_middleware_enabled() is True
@pytest.mark.parametrize("falsy", ["0", "false", "no", "off", "", "garbage"])
def test_individual_flags_falsy_values(
monkeypatch: pytest.MonkeyPatch, falsy: str
) -> None:
_clear_all(monkeypatch)
monkeypatch.setenv("SURFSENSE_ENABLE_RETRY_AFTER", falsy)
flags = reload_for_tests()
assert flags.enable_retry_after is False
def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) -> None:
_clear_all(monkeypatch)
flag_to_env = {
"enable_context_editing": "SURFSENSE_ENABLE_CONTEXT_EDITING",
"enable_compaction_v2": "SURFSENSE_ENABLE_COMPACTION_V2",
"enable_retry_after": "SURFSENSE_ENABLE_RETRY_AFTER",
"enable_model_fallback": "SURFSENSE_ENABLE_MODEL_FALLBACK",
"enable_model_call_limit": "SURFSENSE_ENABLE_MODEL_CALL_LIMIT",
"enable_tool_call_limit": "SURFSENSE_ENABLE_TOOL_CALL_LIMIT",
"enable_tool_call_repair": "SURFSENSE_ENABLE_TOOL_CALL_REPAIR",
"enable_doom_loop": "SURFSENSE_ENABLE_DOOM_LOOP",
"enable_permission": "SURFSENSE_ENABLE_PERMISSION",
"enable_busy_mutex": "SURFSENSE_ENABLE_BUSY_MUTEX",
"enable_llm_tool_selector": "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
"enable_skills": "SURFSENSE_ENABLE_SKILLS",
"enable_specialized_subagents": "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
"enable_otel": "SURFSENSE_ENABLE_OTEL",
}
# `enable_otel` is intentionally orthogonal — it does NOT count toward
# ``any_new_middleware_enabled`` because OTel is observability-only and
# ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}
for attr, env_name in flag_to_env.items():
_clear_all(monkeypatch)
monkeypatch.setenv(env_name, "true")
flags = reload_for_tests()
assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}"
if attr in counts_toward_middleware:
assert flags.any_new_middleware_enabled() is True
else:
assert flags.any_new_middleware_enabled() is False
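# --- Illustrative sketch (editor's addition, not part of the diff) --------
# Env parsing consistent with the truthy/falsy cases above: "garbage" must
# be falsy, so this is an explicit allowlist rather than bool() coercion.
def _env_flag_sketch(raw: str | None) -> bool:
    return (raw or "").strip().lower() in {"1", "true", "yes", "on"}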


@@ -0,0 +1,119 @@
"""Tests for NoopInjectionMiddleware provider-compat logic."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.new_chat.middleware.noop_injection import (
NOOP_TOOL_NAME,
NoopInjectionMiddleware,
_last_ai_has_tool_calls,
_provider_needs_noop,
)
pytestmark = pytest.mark.unit
class _LiteLLMModel:
def _get_ls_params(self):
return {"ls_provider": "litellm"}
class _BedrockModel:
def _get_ls_params(self):
return {"ls_provider": "bedrock"}
class _OpenAIModel:
def _get_ls_params(self):
return {"ls_provider": "openai"}
class _ChatLiteLLM: # name-only fallback
pass
class TestProviderDetection:
def test_litellm(self) -> None:
assert _provider_needs_noop(_LiteLLMModel()) is True
def test_bedrock(self) -> None:
assert _provider_needs_noop(_BedrockModel()) is True
def test_openai_does_not_need(self) -> None:
assert _provider_needs_noop(_OpenAIModel()) is False
def test_class_name_fallback(self) -> None:
assert _provider_needs_noop(_ChatLiteLLM()) is True
class TestHistoryDetection:
def test_last_ai_has_tool_calls(self) -> None:
msgs = [
HumanMessage(content="hi"),
AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]),
]
assert _last_ai_has_tool_calls(msgs) is True
def test_last_ai_no_tool_calls(self) -> None:
msgs = [
HumanMessage(content="hi"),
AIMessage(content="hello"),
]
assert _last_ai_has_tool_calls(msgs) is False
def test_no_ai_in_history(self) -> None:
assert _last_ai_has_tool_calls([HumanMessage(content="hi")]) is False
class _FakeRequest:
def __init__(self, *, tools, messages, model) -> None:
self.tools = tools
self.messages = messages
self.model = model
def override(self, *, tools):
return _FakeRequest(tools=tools, messages=self.messages, model=self.model)
class TestShouldInject:
def test_injects_when_all_conditions_met(self) -> None:
mw = NoopInjectionMiddleware()
msgs = [
HumanMessage(content="hi"),
AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]),
]
req = _FakeRequest(tools=[], messages=msgs, model=_LiteLLMModel())
assert mw._should_inject(req) is True
def test_skips_when_tools_present(self) -> None:
mw = NoopInjectionMiddleware()
req = _FakeRequest(
tools=[object()],
messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])],
model=_LiteLLMModel(),
)
assert mw._should_inject(req) is False
def test_skips_when_no_history_tool_calls(self) -> None:
mw = NoopInjectionMiddleware()
req = _FakeRequest(
tools=[],
messages=[HumanMessage(content="hi")],
model=_LiteLLMModel(),
)
assert mw._should_inject(req) is False
def test_skips_for_openai(self) -> None:
mw = NoopInjectionMiddleware()
req = _FakeRequest(
tools=[],
messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])],
model=_OpenAIModel(),
)
assert mw._should_inject(req) is False
def test_noop_tool_name_is_underscore_noop() -> None:
assert NOOP_TOOL_NAME == "_noop"
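# --- Illustrative sketch (editor's addition, not part of the diff) --------
# Detection consistent with the four provider cases above: prefer the
# ls_provider reported by _get_ls_params(), fall back to the class name.
# The substring matching is an assumption.
def _provider_needs_noop_sketch(model: object) -> bool:
    provider = ""
    get_params = getattr(model, "_get_ls_params", None)
    if callable(get_params):
        provider = str(get_params().get("ls_provider", ""))
    haystack = (provider or type(model).__name__).lower()
    return "litellm" in haystack or "bedrock" in haystack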


@@ -0,0 +1,195 @@
"""Tests for the OtelSpanMiddleware adapter (Tier 3b)."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from app.agents.new_chat.middleware.otel_span import (
OtelSpanMiddleware,
_annotate_model_response,
_annotate_tool_result,
_resolve_input_size,
_resolve_model_attrs,
_resolve_tool_name,
)
pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _disable_otel(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true")
from app.observability import otel as ot
ot.reload_for_tests()
yield
ot.reload_for_tests()
class TestResolveModelAttrs:
def test_extracts_model_name_and_provider(self) -> None:
request = MagicMock()
request.model = MagicMock(spec=["model_name", "provider"])
request.model.model_name = "gpt-4o-mini"
request.model.provider = "openai"
assert _resolve_model_attrs(request) == ("gpt-4o-mini", "openai")
def test_handles_missing_model(self) -> None:
request = MagicMock()
request.model = None
assert _resolve_model_attrs(request) == (None, None)
def test_falls_back_through_attribute_chain(self) -> None:
request = MagicMock()
request.model = MagicMock(spec=["model_id", "_llm_type"])
request.model.model_id = "claude-3-5-sonnet"
request.model._llm_type = "anthropic-chat"
model_id, provider = _resolve_model_attrs(request)
assert model_id == "claude-3-5-sonnet"
assert provider == "anthropic-chat"
class TestResolveToolName:
def test_prefers_request_tool_name(self) -> None:
request = MagicMock()
request.tool = MagicMock(name="ToolStub")
request.tool.name = "scrape_webpage"
assert _resolve_tool_name(request) == "scrape_webpage"
def test_falls_back_to_tool_call_name(self) -> None:
request = MagicMock()
request.tool = None
request.tool_call = {"name": "web_search", "args": {}}
assert _resolve_tool_name(request) == "web_search"
def test_unknown_when_nothing_resolves(self) -> None:
request = MagicMock()
request.tool = None
request.tool_call = {}
assert _resolve_tool_name(request) == "unknown"
class TestResolveInputSize:
def test_returns_repr_length_of_args(self) -> None:
request = MagicMock()
request.tool_call = {"args": {"query": "hello world"}}
size = _resolve_input_size(request)
assert isinstance(size, int)
assert size > 0
def test_handles_no_tool_call(self) -> None:
request = MagicMock()
request.tool_call = None
assert _resolve_input_size(request) is None
class TestAnnotateModelResponse:
def test_attaches_token_counts_when_present(self) -> None:
sp = MagicMock()
msg = AIMessage(
content="hello",
usage_metadata={
"input_tokens": 100,
"output_tokens": 50,
"total_tokens": 150,
},
)
_annotate_model_response(sp, msg)
sp.set_attribute.assert_any_call("tokens.prompt", 100)
sp.set_attribute.assert_any_call("tokens.completion", 50)
sp.set_attribute.assert_any_call("tokens.total", 150)
def test_handles_response_with_no_metadata(self) -> None:
sp = MagicMock()
msg = AIMessage(content="hello")
# Should not raise even when usage_metadata is missing
_annotate_model_response(sp, msg)
class TestAnnotateToolResult:
def test_records_size_and_status(self) -> None:
sp = MagicMock()
result = ToolMessage(
content="result text",
tool_call_id="abc",
status="success",
)
_annotate_tool_result(sp, result)
sp.set_attribute.assert_any_call("tool.output.size", len("result text"))
sp.set_attribute.assert_any_call("tool.status", "success")
def test_marks_errors(self) -> None:
sp = MagicMock()
result = ToolMessage(
content="oops",
tool_call_id="abc",
additional_kwargs={"error": {"code": "x"}},
)
_annotate_tool_result(sp, result)
sp.set_attribute.assert_any_call("tool.error", True)
@pytest.mark.asyncio
class TestMiddlewareIntegration:
async def test_awrap_model_call_passes_through_when_disabled(self) -> None:
mw = OtelSpanMiddleware()
called: dict[str, Any] = {}
async def handler(req):
called["req"] = req
return AIMessage(content="ok")
request = MagicMock()
result = await mw.awrap_model_call(request, handler)
assert called["req"] is request
assert isinstance(result, AIMessage)
assert result.content == "ok"
async def test_awrap_tool_call_passes_through_when_disabled(self) -> None:
mw = OtelSpanMiddleware()
async def handler(req):
return ToolMessage(content="result", tool_call_id="abc")
request = MagicMock()
result = await mw.awrap_tool_call(request, handler)
assert isinstance(result, ToolMessage)
assert result.content == "result"
async def test_awrap_model_call_propagates_exceptions(self) -> None:
mw = OtelSpanMiddleware()
async def handler(req):
raise ValueError("boom")
with pytest.raises(ValueError):
await mw.awrap_model_call(MagicMock(), handler)
async def test_with_otel_enabled_does_not_alter_result(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
from app.observability import otel as ot
ot.reload_for_tests()
try:
mw = OtelSpanMiddleware()
async def handler(req):
return AIMessage(content="enabled")
request = MagicMock()
request.model = MagicMock()
request.model.model_name = "gpt-4o"
request.model.provider = "openai"
result = await mw.awrap_model_call(request, handler)
assert isinstance(result, AIMessage)
assert result.content == "enabled"
finally:
ot.reload_for_tests()
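# --- Illustrative sketch (editor's addition, not part of the diff) --------
# The span attributes asserted in TestAnnotateModelResponse imply a mapping
# like this; tolerating a missing usage_metadata is required by the tests.
def _annotate_model_response_sketch(span, message) -> None:
    usage = getattr(message, "usage_metadata", None) or {}
    for span_key, usage_key in (
        ("tokens.prompt", "input_tokens"),
        ("tokens.completion", "output_tokens"),
        ("tokens.total", "total_tokens"),
    ):
        if usage_key in usage:
            span.set_attribute(span_key, usage[usage_key])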


@@ -0,0 +1,116 @@
"""Tests for PermissionMiddleware end-to-end behavior."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from app.agents.new_chat.errors import CorrectedError, RejectedError
from app.agents.new_chat.middleware.permission import PermissionMiddleware
from app.agents.new_chat.permissions import Rule, Ruleset
pytestmark = pytest.mark.unit
class _FakeRuntime:
config: dict = {"configurable": {"thread_id": "test"}}
def _msg(*tool_calls: dict) -> AIMessage:
return AIMessage(content="", tool_calls=list(tool_calls))
class TestAllow:
def test_passthrough_when_allow(self) -> None:
rs = Ruleset(rules=[Rule("send_email", "*", "allow")])
mw = PermissionMiddleware(rulesets=[rs])
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
out = mw.after_model(state, _FakeRuntime())
assert out is None # no change
class TestDeny:
def test_replaces_with_deny_tool_message(self) -> None:
rs = Ruleset(rules=[Rule("send_email", "*", "deny")])
mw = PermissionMiddleware(rulesets=[rs])
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
out = mw.after_model(state, _FakeRuntime())
assert out is not None
msgs = out["messages"]
# Find the deny ToolMessage
deny_msgs = [m for m in msgs if isinstance(m, ToolMessage)]
assert len(deny_msgs) == 1
assert deny_msgs[0].status == "error"
assert "permission_denied" in str(deny_msgs[0].additional_kwargs)
# AIMessage's tool_calls should now be empty (denied call removed)
ai_msg = next(m for m in msgs if isinstance(m, AIMessage))
assert ai_msg.tool_calls == []
def test_mixed_allow_deny(self) -> None:
rs = Ruleset(
rules=[
Rule("send_email", "*", "deny"),
Rule("read", "*", "allow"),
]
)
mw = PermissionMiddleware(rulesets=[rs])
state = {
"messages": [
_msg(
{"name": "send_email", "args": {}, "id": "1"},
{"name": "read", "args": {}, "id": "2"},
)
]
}
out = mw.after_model(state, _FakeRuntime())
assert out is not None
ai_msg = next(m for m in out["messages"] if isinstance(m, AIMessage))
assert len(ai_msg.tool_calls) == 1
assert ai_msg.tool_calls[0]["name"] == "read"
class TestAsk:
def test_reject_without_feedback_raises(self) -> None:
# Default: nothing matches -> ask
rs = Ruleset(rules=[])
mw = PermissionMiddleware(rulesets=[rs])
# Bypass real interrupt — patch the helper
mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment]
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
with pytest.raises(RejectedError):
mw.after_model(state, _FakeRuntime())
def test_reject_with_feedback_raises_corrected(self) -> None:
rs = Ruleset(rules=[])
mw = PermissionMiddleware(rulesets=[rs])
mw._raise_interrupt = lambda **kw: { # type: ignore[assignment]
"decision_type": "reject",
"feedback": "use a different subject line",
}
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
with pytest.raises(CorrectedError) as excinfo:
mw.after_model(state, _FakeRuntime())
assert excinfo.value.feedback == "use a different subject line"
def test_once_proceeds_without_persisting(self) -> None:
mw = PermissionMiddleware(rulesets=[])
mw._raise_interrupt = lambda **kw: {"decision_type": "once"} # type: ignore[assignment]
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
out = mw.after_model(state, _FakeRuntime())
# No state change because all calls kept
assert out is None
# No new rule persisted
assert mw._runtime_ruleset.rules == []
def test_always_persists_runtime_rule(self) -> None:
mw = PermissionMiddleware(rulesets=[])
mw._raise_interrupt = lambda **kw: {"decision_type": "always"} # type: ignore[assignment]
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
out = mw.after_model(state, _FakeRuntime())
assert out is None # call kept
# Runtime ruleset got the always-allow rule
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
assert any(
r.permission == "send_email" for r in new_rules
)
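# --- Illustrative sketch (editor's addition, not part of the diff) --------
# Shape of the synthesized deny ToolMessage implied by TestDeny. The error
# status and the "permission_denied" marker come from the assertions; the
# content wording and the additional_kwargs layout are assumptions.
def _deny_message_sketch(tool_call: dict) -> ToolMessage:
    return ToolMessage(
        content=f"Tool call {tool_call['name']!r} was denied by a permission rule.",
        tool_call_id=tool_call["id"],
        status="error",
        additional_kwargs={"error": "permission_denied"},
    )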


@@ -0,0 +1,111 @@
"""Tests for the wildcard matcher and rule evaluator (opencode evaluate.ts parity)."""
from __future__ import annotations
import pytest
from app.agents.new_chat.permissions import (
Rule,
Ruleset,
aggregate_action,
evaluate,
evaluate_many,
wildcard_match,
)
pytestmark = pytest.mark.unit
class TestWildcardMatch:
@pytest.mark.parametrize(
"value,pattern,expected",
[
("edit", "edit", True),
("edit", "*", True),
("read", "edit", False),
("/documents/secrets/x", "/documents/secrets/**", True),
# Single-segment glob: '*' does not cross '/'
("/documents/secrets/x", "/documents/*/x", True),
("/documents/foo/bar/x", "/documents/*/x", False),
("/documents/foo/x", "/documents/*/x", True),
("linear_create", "linear_*", True),
("notion_create", "linear_*", False),
# ':' is not a separator, so '*' matches it
("mcp:notion:create_page", "mcp:*", True),
("mcp:notion:create_page", "mcp:**", True),
# But '/' IS a separator
("foo/bar", "foo/*", True),
("foo/bar/baz", "foo/*", False),
],
)
def test_match(self, value: str, pattern: str, expected: bool) -> None:
assert wildcard_match(value, pattern) is expected
class TestEvaluate:
def test_default_action_is_ask(self) -> None:
rule = evaluate("edit", "/foo/bar")
assert rule.action == "ask"
assert rule.permission == "edit"
def test_last_match_wins(self) -> None:
rs = Ruleset(
rules=[
Rule("edit", "*", "allow"),
Rule("edit", "/secrets/**", "deny"),
]
)
# Second rule (deny) is more specific AND specified later
assert evaluate("edit", "/secrets/x", rs).action == "deny"
# First rule (allow) covers the rest
assert evaluate("edit", "/public/x", rs).action == "allow"
def test_layered_rulesets_later_overrides_earlier(self) -> None:
defaults = Ruleset(rules=[Rule("edit", "*", "ask")], origin="defaults")
space = Ruleset(rules=[Rule("edit", "*", "allow")], origin="space")
thread = Ruleset(rules=[Rule("edit", "*", "deny")], origin="thread")
# All three layered: thread wins
assert evaluate("edit", "x", defaults, space, thread).action == "deny"
# Without thread: space wins
assert evaluate("edit", "x", defaults, space).action == "allow"
def test_permission_wildcard(self) -> None:
rs = Ruleset(rules=[Rule("linear_*", "*", "allow")])
assert evaluate("linear_create_issue", "x", rs).action == "allow"
assert evaluate("notion_create", "x", rs).action == "ask"
def test_pattern_wildcard(self) -> None:
rs = Ruleset(rules=[Rule("edit", "/documents/secrets/**", "deny")])
assert evaluate("edit", "/documents/secrets/foo", rs).action == "deny"
assert evaluate("edit", "/documents/public/foo", rs).action == "ask"
def test_evaluate_many(self) -> None:
rs = Ruleset(
rules=[
Rule("edit", "*", "allow"),
Rule("edit", "/secrets/*", "deny"),
]
)
results = evaluate_many("edit", ["/public/x", "/secrets/y"], rs)
assert [r.action for r in results] == ["allow", "deny"]
class TestAggregateAction:
def test_any_deny_means_deny(self) -> None:
rules = [
Rule("a", "*", "allow"),
Rule("a", "*", "deny"),
Rule("a", "*", "ask"),
]
assert aggregate_action(rules) == "deny"
def test_any_ask_means_ask_when_no_deny(self) -> None:
rules = [Rule("a", "*", "allow"), Rule("a", "*", "ask")]
assert aggregate_action(rules) == "ask"
def test_all_allow_means_allow(self) -> None:
rules = [Rule("a", "*", "allow"), Rule("a", "*", "allow")]
assert aggregate_action(rules) == "allow"
def test_empty_means_ask(self) -> None:
assert aggregate_action([]) == "ask"
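# --- Illustrative sketch (editor's addition, not part of the diff) --------
# A regex translation consistent with the semantics pinned above: '*' stays
# within a '/'-delimited segment, '**' crosses segments, and ':' is not a
# separator. The real matcher may be implemented differently.
import re

def wildcard_match_sketch(value: str, pattern: str) -> bool:
    regex = (
        re.escape(pattern)
        .replace(r"\*\*", ".*")     # '**' may cross '/'
        .replace(r"\*", "[^/]*")    # '*' must not cross '/'
    )
    return re.fullmatch(regex, value) is not None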


@@ -0,0 +1,187 @@
"""Unit tests for the SurfSense plugin entry-point loader (Tier 6)."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from langchain.agents.middleware import AgentMiddleware
from app.agents.new_chat.plugin_loader import (
PLUGIN_ENTRY_POINT_GROUP,
PluginContext,
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
from app.agents.new_chat.plugins.year_substituter import (
_YearSubstituterMiddleware,
make_middleware as year_substituter_factory,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _DummyMiddleware(AgentMiddleware):
"""Trivial middleware used as the success-path return value."""
tools = ()
def _ctx() -> PluginContext:
return PluginContext.build(
search_space_id=1,
user_id="u",
thread_visibility="PRIVATE", # type: ignore[arg-type]
llm=MagicMock(),
)
class _FakeEntryPoint:
"""Stand-in for ``importlib.metadata.EntryPoint``."""
def __init__(self, name: str, factory) -> None:
self.name = name
self._factory = factory
def load(self):
return self._factory
# ---------------------------------------------------------------------------
# Loader behaviour
# ---------------------------------------------------------------------------
class TestPluginLoaderBasics:
def test_returns_empty_when_allowlist_is_empty(self) -> None:
assert load_plugin_middlewares(_ctx(), allowed_plugin_names=[]) == []
def test_skips_non_allowlisted_plugin(self) -> None:
called = []
def factory(_): # would be an obvious bug if called
called.append(True)
return _DummyMiddleware()
ep = _FakeEntryPoint("dangerous_plugin", factory)
with patch(
"app.agents.new_chat.plugin_loader.entry_points",
return_value=[ep],
):
result = load_plugin_middlewares(_ctx(), allowed_plugin_names=["allowed_only"])
assert result == []
assert not called
def test_loads_allowlisted_plugin(self) -> None:
ep = _FakeEntryPoint("year_substituter", year_substituter_factory)
with patch(
"app.agents.new_chat.plugin_loader.entry_points",
return_value=[ep],
):
result = load_plugin_middlewares(
_ctx(), allowed_plugin_names={"year_substituter"}
)
assert len(result) == 1
assert isinstance(result[0], _YearSubstituterMiddleware)
class TestPluginLoaderIsolation:
def test_factory_exception_is_isolated(self) -> None:
def crashing_factory(_):
raise RuntimeError("boom")
ep = _FakeEntryPoint("buggy", crashing_factory)
with patch(
"app.agents.new_chat.plugin_loader.entry_points",
return_value=[ep],
):
result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"buggy"})
assert result == [] # construction continued without the plugin
def test_non_middleware_return_is_rejected(self) -> None:
def bad_factory(_):
return "not a middleware"
ep = _FakeEntryPoint("liar", bad_factory)
with patch(
"app.agents.new_chat.plugin_loader.entry_points",
return_value=[ep],
):
result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"liar"})
assert result == []
def test_load_phase_exception_is_isolated(self) -> None:
class _BrokenEP:
name = "broken"
def load(self):
raise ImportError("cannot import")
with patch(
"app.agents.new_chat.plugin_loader.entry_points",
return_value=[_BrokenEP()],
):
result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"broken"})
assert result == []
def test_one_failure_does_not_block_others(self) -> None:
"""Two plugins; one crashes during factory; the other still loads."""
def crashing_factory(_):
raise RuntimeError("boom")
eps = [
_FakeEntryPoint("crashing", crashing_factory),
_FakeEntryPoint("ok", year_substituter_factory),
]
with patch(
"app.agents.new_chat.plugin_loader.entry_points", return_value=eps
):
result = load_plugin_middlewares(
_ctx(), allowed_plugin_names={"crashing", "ok"}
)
assert len(result) == 1
assert isinstance(result[0], _YearSubstituterMiddleware)
class TestAllowlistEnv:
def test_empty_env_returns_empty_set(self, monkeypatch) -> None:
monkeypatch.delenv("SURFSENSE_ALLOWED_PLUGINS", raising=False)
assert load_allowed_plugin_names_from_env() == set()
def test_parses_comma_separated_value(self, monkeypatch) -> None:
monkeypatch.setenv(
"SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , "
)
assert load_allowed_plugin_names_from_env() == {
"year_substituter",
"noisy",
}
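
# Illustrative sketch of the parsing contract TestAllowlistEnv asserts:
# comma-separated names, whitespace trimmed, empty entries dropped. The env
# variable name is taken from the tests above; the shipped loader may read a
# settings object instead.
import os


def _load_allowlist_sketch(var: str = "SURFSENSE_ALLOWED_PLUGINS") -> set[str]:
    raw = os.environ.get(var, "")
    return {name.strip() for name in raw.split(",") if name.strip()}
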
class TestPluginContext:
def test_build_includes_required_fields(self) -> None:
llm = MagicMock()
ctx = PluginContext.build(
search_space_id=42,
user_id="user-1",
thread_visibility="PRIVATE", # type: ignore[arg-type]
llm=llm,
)
assert ctx["search_space_id"] == 42
assert ctx["user_id"] == "user-1"
assert ctx["llm"] is llm
def test_does_not_carry_secrets_or_db_session(self) -> None:
ctx = _ctx()
# If a future change tries to add these keys, this test will fail loudly.
for forbidden in ("api_key", "secret", "db_session", "session"):
assert forbidden not in ctx
class TestEntryPointGroup:
def test_group_name_matches_pyproject_convention(self) -> None:
# Plugins register under `surfsense.plugins`; this is part of our
# public contract for plugin authors.
assert PLUGIN_ENTRY_POINT_GROUP == "surfsense.plugins"
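
# What a third-party plugin would ship, sketched from the contract these
# tests exercise: register a factory under the "surfsense.plugins" entry
# point group, e.g. in pyproject.toml:
#
#   [project.entry-points."surfsense.plugins"]
#   my_plugin = "my_pkg.plugin:make_middleware"
#
# The factory accepts the PluginContext and returns an AgentMiddleware.
# _NoopMiddleware and make_middleware are illustrative names, not part of
# the shipped codebase.
class _NoopMiddleware(AgentMiddleware):
    tools = ()


def make_middleware(ctx: PluginContext) -> AgentMiddleware:
    # ctx exposes search_space_id / user_id / llm; by contract it never
    # carries secrets or a DB session (see TestPluginContext above).
    return _NoopMiddleware()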

View file

@ -0,0 +1,107 @@
"""Tests for RetryAfterMiddleware Retry-After parsing and retry decision logic."""
from __future__ import annotations
import pytest
from app.agents.new_chat.middleware.retry_after import (
RetryAfterMiddleware,
_extract_retry_after_seconds,
_is_non_retryable,
)
pytestmark = pytest.mark.unit
class _FakeResponse:
def __init__(self, headers: dict[str, str]) -> None:
self.headers = headers
class _FakeRateLimit(Exception):
def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None:
super().__init__(msg)
if headers is not None:
self.response = _FakeResponse(headers)
class TestExtractRetryAfter:
def test_seconds_header(self) -> None:
exc = _FakeRateLimit("rate", {"Retry-After": "30"})
assert _extract_retry_after_seconds(exc) == 30.0
    def test_milliseconds_header_overrides_seconds(self) -> None:
        # Both headers present: the millisecond hint must win.
        exc = _FakeRateLimit("rate", {"Retry-After": "30", "retry-after-ms": "1500"})
        assert _extract_retry_after_seconds(exc) == 1.5
def test_case_insensitive(self) -> None:
exc = _FakeRateLimit("rate", {"RETRY-AFTER": "12"})
assert _extract_retry_after_seconds(exc) == 12.0
def test_falls_back_to_message_regex(self) -> None:
exc = Exception("Please retry after 7 seconds")
assert _extract_retry_after_seconds(exc) == 7.0
def test_returns_none_when_no_hint(self) -> None:
exc = Exception("oops")
assert _extract_retry_after_seconds(exc) is None
def test_handles_missing_headers_attr(self) -> None:
exc = ValueError("no headers")
assert _extract_retry_after_seconds(exc) is None
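
# Illustrative sketch of the extraction order these tests imply:
# retry-after-ms (milliseconds) beats Retry-After (seconds), header lookup
# is case-insensitive, and a "retry after N seconds" phrase in the message
# is the last resort. Not the shipped implementation.
import re


def _extract_retry_after_sketch(exc: Exception) -> float | None:
    headers = getattr(getattr(exc, "response", None), "headers", None) or {}
    lowered = {k.lower(): v for k, v in headers.items()}
    if "retry-after-ms" in lowered:
        return float(lowered["retry-after-ms"]) / 1000.0
    if "retry-after" in lowered:
        return float(lowered["retry-after"])
    match = re.search(r"retry after (\d+(?:\.\d+)?)", str(exc), re.IGNORECASE)
    return float(match.group(1)) if match else None
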
class TestIsNonRetryable:
@pytest.mark.parametrize(
"name",
["ContextWindowExceededError", "AuthenticationError", "InvalidRequestError"],
)
def test_non_retryable_classes(self, name: str) -> None:
cls = type(name, (Exception,), {})
assert _is_non_retryable(cls("x")) is True
def test_generic_exception_is_retryable(self) -> None:
assert _is_non_retryable(RuntimeError("transient")) is False
class TestDelayCalculation:
def test_takes_max_of_backoff_and_header(self) -> None:
mw = RetryAfterMiddleware(max_retries=3, initial_delay=1.0, jitter=False)
exc = _FakeRateLimit("rl", {"retry-after": "10"})
delay = mw._delay_for_attempt(0, exc)
assert delay == pytest.approx(10.0)
def test_uses_backoff_when_no_header(self) -> None:
mw = RetryAfterMiddleware(
max_retries=3, initial_delay=2.0, backoff_factor=2.0, jitter=False
)
delay = mw._delay_for_attempt(2, RuntimeError("transient"))
# 2 * 2^2 = 8
assert delay == pytest.approx(8.0)
def test_caps_at_max_delay(self) -> None:
mw = RetryAfterMiddleware(
max_retries=3,
initial_delay=10.0,
backoff_factor=10.0,
max_delay=15.0,
jitter=False,
)
        delay = mw._delay_for_attempt(5, RuntimeError("x"))
        # 10 * 10^5 is far above the cap, so the cap itself is returned.
        assert delay == pytest.approx(15.0)
class TestShouldRetry:
def test_default_retries_generic(self) -> None:
mw = RetryAfterMiddleware()
assert mw._should_retry(RuntimeError("transient")) is True
def test_default_skips_non_retryable(self) -> None:
mw = RetryAfterMiddleware()
cls = type("ContextWindowExceededError", (Exception,), {})
assert mw._should_retry(cls("too big")) is False
def test_custom_retry_on(self) -> None:
mw = RetryAfterMiddleware(retry_on=lambda exc: isinstance(exc, ValueError))
assert mw._should_retry(ValueError()) is True
assert mw._should_retry(KeyError()) is False
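
# Illustrative sketch of the delay rule TestDelayCalculation pins down:
# exponential backoff, raised to any larger Retry-After hint, then capped
# at max_delay (jitter omitted for clarity).
def _delay_sketch(
    attempt: int,
    exc: Exception,
    *,
    initial_delay: float,
    backoff_factor: float,
    max_delay: float,
) -> float:
    backoff = initial_delay * backoff_factor**attempt
    hint = _extract_retry_after_seconds(exc)
    delay = max(backoff, hint) if hint is not None else backoff
    return min(delay, max_delay)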

View file

@ -0,0 +1,242 @@
"""Tests for the skills backends used by SurfSense's SkillsMiddleware."""
from __future__ import annotations
import asyncio
from pathlib import Path
import pytest
from app.agents.new_chat.middleware.skills_backends import (
SKILLS_BUILTIN_PREFIX,
SKILLS_SPACE_PREFIX,
BuiltinSkillsBackend,
SearchSpaceSkillsBackend,
build_skills_backend_factory,
default_skills_sources,
)
@pytest.fixture
def skills_root(tmp_path: Path) -> Path:
"""Build a small synthetic skill-tree used by the tests."""
root = tmp_path / "skills"
(root / "alpha").mkdir(parents=True)
(root / "alpha" / "SKILL.md").write_text(
"---\nname: alpha\ndescription: alpha skill\n---\n# Alpha\n"
)
(root / "beta").mkdir(parents=True)
(root / "beta" / "SKILL.md").write_text(
"---\nname: beta\ndescription: beta skill\n---\n# Beta\n"
)
(root / "_orphan_file.md").write_text("not a skill, just a stray file")
return root
class TestBuiltinSkillsBackendListing:
def test_lists_skill_directories_at_root(self, skills_root: Path) -> None:
backend = BuiltinSkillsBackend(skills_root)
infos = backend.ls_info("/")
names = {info["path"] for info in infos}
assert "/alpha" in names
assert "/beta" in names
assert "/_orphan_file.md" in names
for info in infos:
if info["path"] in {"/alpha", "/beta"}:
assert info["is_dir"] is True
def test_lists_skill_md_under_skill_directory(self, skills_root: Path) -> None:
backend = BuiltinSkillsBackend(skills_root)
infos = backend.ls_info("/alpha")
paths = {info["path"] for info in infos}
assert paths == {"/alpha/SKILL.md"}
assert infos[0]["is_dir"] is False
assert infos[0]["size"] > 0
def test_returns_empty_for_missing_path(self, skills_root: Path) -> None:
backend = BuiltinSkillsBackend(skills_root)
assert backend.ls_info("/nonexistent") == []
def test_returns_empty_when_root_missing(self, tmp_path: Path) -> None:
backend = BuiltinSkillsBackend(tmp_path / "definitely-missing")
assert backend.ls_info("/") == []
assert backend.download_files(["/x/SKILL.md"])[0].error == "file_not_found"
def test_refuses_path_traversal(self, skills_root: Path) -> None:
backend = BuiltinSkillsBackend(skills_root)
assert backend.ls_info("/../../../etc") == []
responses = backend.download_files(["/../../../etc/passwd"])
assert responses[0].error == "invalid_path"
class TestBuiltinSkillsBackendDownload:
def test_downloads_skill_md_content(self, skills_root: Path) -> None:
backend = BuiltinSkillsBackend(skills_root)
responses = backend.download_files(["/alpha/SKILL.md", "/beta/SKILL.md"])
assert len(responses) == 2
assert responses[0].path == "/alpha/SKILL.md"
assert responses[0].content is not None
assert b"name: alpha" in responses[0].content
assert responses[1].error is None
def test_marks_directory_as_is_directory_error(self, skills_root: Path) -> None:
backend = BuiltinSkillsBackend(skills_root)
responses = backend.download_files(["/alpha"])
assert responses[0].error == "is_directory"
def test_marks_missing_file_as_file_not_found(self, skills_root: Path) -> None:
backend = BuiltinSkillsBackend(skills_root)
responses = backend.download_files(["/alpha/missing.md"])
assert responses[0].error == "file_not_found"
assert responses[0].content is None
def test_response_path_matches_input_for_correlation(
self, skills_root: Path
) -> None:
backend = BuiltinSkillsBackend(skills_root)
inputs = ["/alpha/SKILL.md", "/missing.md", "/beta/SKILL.md"]
responses = backend.download_files(inputs)
assert [r.path for r in responses] == inputs
class TestBuiltinSkillsBackendIntegration:
"""Mirror the call sequence the SkillsMiddleware actually uses."""
def test_skills_middleware_call_pattern(self, skills_root: Path) -> None:
backend = BuiltinSkillsBackend(skills_root)
infos = asyncio.run(backend.als_info("/"))
skill_dirs = [i["path"] for i in infos if i.get("is_dir")]
assert sorted(skill_dirs) == ["/alpha", "/beta"]
skill_md_paths = [f"{p}/SKILL.md" for p in skill_dirs]
responses = asyncio.run(backend.adownload_files(skill_md_paths))
assert all(r.error is None for r in responses)
assert all(r.content is not None for r in responses)
class TestBundledSkills:
def test_default_root_resolves_to_repo_skills_dir(self) -> None:
backend = BuiltinSkillsBackend()
assert backend.root.name == "builtin"
assert backend.root.parent.name == "skills"
def test_bundled_starter_skills_are_present(self) -> None:
backend = BuiltinSkillsBackend()
infos = backend.ls_info("/")
names = {info["path"].lstrip("/") for info in infos if info.get("is_dir")}
# Five starter skills required by the Tier 4 plan.
for required in (
"kb-research",
"report-writing",
"meeting-prep",
"slack-summary",
"email-drafting",
):
assert required in names, f"missing starter skill: {required}"
def test_each_starter_skill_has_valid_skill_md(self) -> None:
backend = BuiltinSkillsBackend()
infos = backend.ls_info("/")
skill_dirs = [info["path"] for info in infos if info.get("is_dir")]
for skill_dir in skill_dirs:
md_path = f"{skill_dir}/SKILL.md"
response = backend.download_files([md_path])[0]
assert response.error is None, f"missing SKILL.md in {skill_dir}"
content = response.content.decode("utf-8").replace("\r\n", "\n")
assert content.startswith("---\n"), f"missing frontmatter in {skill_dir}"
assert "\nname:" in content
assert "\ndescription:" in content
class _FakeKBBackend:
"""Stand-in for :class:`KBPostgresBackend` with the two methods we need."""
def __init__(self, listing: list[dict], file_contents: dict[str, bytes]) -> None:
self._listing = listing
self._file_contents = file_contents
self.last_ls_path: str | None = None
self.last_download_paths: list[str] | None = None
async def als_info(self, path: str):
self.last_ls_path = path
return self._listing
async def adownload_files(self, paths):
from deepagents.backends.protocol import FileDownloadResponse
self.last_download_paths = list(paths)
out: list[FileDownloadResponse] = []
for p in paths:
content = self._file_contents.get(p)
if content is None:
out.append(FileDownloadResponse(path=p, error="file_not_found"))
else:
out.append(FileDownloadResponse(path=p, content=content))
return out
class TestSearchSpaceSkillsBackend:
def test_remaps_paths_when_listing(self) -> None:
listing = [
{"path": "/documents/_skills/policy", "is_dir": True},
{"path": "/documents/_skills/policy/SKILL.md", "is_dir": False},
{"path": "/documents/other-folder/x.md", "is_dir": False},
]
kb = _FakeKBBackend(listing=listing, file_contents={})
backend = SearchSpaceSkillsBackend(kb)
infos = asyncio.run(backend.als_info("/"))
assert kb.last_ls_path == "/documents/_skills"
paths = [info["path"] for info in infos]
assert "/policy" in paths
assert "/policy/SKILL.md" in paths
# Unrelated KB documents must NOT leak into the skills namespace.
assert all(not p.startswith("/documents") for p in paths)
def test_remaps_paths_when_downloading(self) -> None:
kb = _FakeKBBackend(
listing=[],
file_contents={
"/documents/_skills/policy/SKILL.md": b"---\nname: policy\n---\n",
},
)
backend = SearchSpaceSkillsBackend(kb)
responses = asyncio.run(backend.adownload_files(["/policy/SKILL.md"]))
assert kb.last_download_paths == ["/documents/_skills/policy/SKILL.md"]
assert responses[0].path == "/policy/SKILL.md"
assert responses[0].error is None
assert responses[0].content is not None
def test_sync_methods_raise_not_implemented(self) -> None:
backend = SearchSpaceSkillsBackend(_FakeKBBackend([], {}))
with pytest.raises(NotImplementedError):
backend.ls_info("/")
with pytest.raises(NotImplementedError):
backend.download_files(["/x"])
def test_custom_kb_root_is_honored(self) -> None:
kb = _FakeKBBackend(
listing=[
{"path": "/skills_admin/x", "is_dir": True},
],
file_contents={},
)
backend = SearchSpaceSkillsBackend(kb, kb_root="/skills_admin")
infos = asyncio.run(backend.als_info("/"))
assert kb.last_ls_path == "/skills_admin"
assert infos[0]["path"] == "/x"
class TestBackendFactory:
def test_builtin_only_factory_returns_composite(self) -> None:
factory = build_skills_backend_factory()
backend = factory(runtime=None) # type: ignore[arg-type]
from deepagents.backends.composite import CompositeBackend
assert isinstance(backend, CompositeBackend)
assert SKILLS_BUILTIN_PREFIX in backend.routes
assert SKILLS_SPACE_PREFIX not in backend.routes
def test_default_skills_sources_lists_builtin_then_space(self) -> None:
sources = default_skills_sources()
assert sources == [SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX]
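
# Illustrative sketch of the remapping TestSearchSpaceSkillsBackend
# verifies: the skills namespace "/" is a view over kb_root (default
# "/documents/_skills"), with paths translated on the way in and out.
# The real class adds error handling and the matching download remap.
class _RemapSketch:
    def __init__(self, kb, kb_root: str = "/documents/_skills") -> None:
        self._kb = kb
        self._root = kb_root

    async def als_info(self, path: str):
        infos = await self._kb.als_info(self._root + path.rstrip("/"))
        return [
            {**info, "path": info["path"][len(self._root):]}
            for info in infos
            if info["path"].startswith(self._root + "/")
        ]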

View file

@ -0,0 +1,338 @@
"""Tests for the specialized subagents (explore / report_writer / connector_negotiator)."""
from __future__ import annotations
from langchain_core.tools import tool
from app.agents.new_chat.middleware.permission import PermissionMiddleware
from app.agents.new_chat.subagents import (
build_connector_negotiator_subagent,
build_explore_subagent,
build_report_writer_subagent,
build_specialized_subagents,
)
from app.agents.new_chat.subagents.config import (
EXPLORE_READ_TOOLS,
REPORT_WRITER_TOOLS,
WRITE_TOOL_DENY_PATTERNS,
)
# ---------------------------------------------------------------------------
# Fake tools used to verify filtering & permission behavior
# ---------------------------------------------------------------------------
@tool
def search_surfsense_docs(query: str) -> str:
"""Search the user's KB."""
return ""
@tool
def web_search(query: str) -> str:
"""Search the public web."""
return ""
@tool
def scrape_webpage(url: str) -> str:
"""Scrape a single webpage."""
return ""
@tool
def read_file(path: str) -> str:
"""Read a file."""
return ""
@tool
def ls_tree(path: str) -> str:
"""List a tree."""
return ""
@tool
def grep(pattern: str) -> str:
"""Grep."""
return ""
@tool
def update_memory(content: str) -> str:
"""Update the user's memory."""
return ""
@tool
def edit_file(path: str, old: str, new: str) -> str:
"""Edit a file."""
return ""
@tool
def linear_create_issue(title: str) -> str:
"""Create a Linear issue."""
return ""
@tool
def slack_send_message(channel: str, text: str) -> str:
"""Send a Slack message."""
return ""
@tool
def get_connected_accounts() -> str:
"""List connected accounts."""
return ""
@tool
def generate_report(topic: str) -> str:
"""Generate a report artifact."""
return ""
ALL_TOOLS = [
search_surfsense_docs,
web_search,
scrape_webpage,
read_file,
ls_tree,
grep,
update_memory,
edit_file,
linear_create_issue,
slack_send_message,
get_connected_accounts,
generate_report,
]
class TestExploreSubagent:
def test_only_read_tools_are_exposed(self) -> None:
spec = build_explore_subagent(tools=ALL_TOOLS)
names = {t.name for t in spec["tools"]} # type: ignore[index]
assert names == EXPLORE_READ_TOOLS & {t.name for t in ALL_TOOLS}
assert "update_memory" not in names
assert "linear_create_issue" not in names
assert "edit_file" not in names
def test_includes_permission_middleware_with_deny_rules(self) -> None:
spec = build_explore_subagent(tools=ALL_TOOLS)
permission_mws = [
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
]
assert len(permission_mws) == 1
ruleset = permission_mws[0]._static_rulesets[0]
assert ruleset.origin == "subagent_explore"
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
assert "update_memory" in deny_patterns
assert "edit_file" in deny_patterns
assert "*create*" in deny_patterns
assert "*send*" in deny_patterns
def test_skills_inherits_default_sources(self) -> None:
spec = build_explore_subagent(tools=ALL_TOOLS)
assert spec["skills"] == ["/skills/builtin/", "/skills/space/"] # type: ignore[index]
def test_name_and_description_match_contract(self) -> None:
spec = build_explore_subagent(tools=ALL_TOOLS)
assert spec["name"] == "explore"
assert "read-only" in spec["description"].lower()
def test_includes_dedup_and_patch_middleware(self) -> None:
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware
spec = build_explore_subagent(tools=ALL_TOOLS)
types = {type(m) for m in spec["middleware"]} # type: ignore[index]
assert PatchToolCallsMiddleware in types
assert DedupHITLToolCallsMiddleware in types
class TestReportWriterSubagent:
def test_exposes_only_report_writing_tools(self) -> None:
spec = build_report_writer_subagent(tools=ALL_TOOLS)
names = {t.name for t in spec["tools"]} # type: ignore[index]
assert names == REPORT_WRITER_TOOLS & {t.name for t in ALL_TOOLS}
assert "generate_report" in names
assert "search_surfsense_docs" in names
def test_deny_rules_block_writes_but_allow_generate_report(self) -> None:
spec = build_report_writer_subagent(tools=ALL_TOOLS)
permission_mws = [
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
]
ruleset = permission_mws[0]._static_rulesets[0]
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
assert "update_memory" in deny_patterns
# generate_report MUST not be denied — it's the whole point of the subagent.
assert "generate_report" not in deny_patterns
# No deny pattern should match `generate_report` either.
assert all(
not _wildcard_matches(pattern, "generate_report")
for pattern in deny_patterns
)
class TestConnectorNegotiatorSubagent:
def test_inherits_all_parent_tools(self) -> None:
spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
names = {t.name for t in spec["tools"]} # type: ignore[index]
# Every parent tool is inherited; the deny ruleset enforces behavior
# at execution time instead of trimming the tool list.
assert names == {t.name for t in ALL_TOOLS}
def test_get_connected_accounts_is_present(self) -> None:
spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
names = {t.name for t in spec["tools"]} # type: ignore[index]
assert "get_connected_accounts" in names
def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None:
spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
permission_mws = [
m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index]
]
ruleset = permission_mws[0]._static_rulesets[0]
deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
# `linear_create_issue` matches the `*_create` deny pattern.
assert any(
_wildcard_matches(p, "linear_create_issue") for p in deny_patterns
)
assert any(
_wildcard_matches(p, "slack_send_message") for p in deny_patterns
)
class TestBuildSpecializedSubagents:
def test_returns_three_specs(self) -> None:
specs = build_specialized_subagents(tools=ALL_TOOLS)
names = [s["name"] for s in specs] # type: ignore[index]
assert names == ["explore", "report_writer", "connector_negotiator"]
def test_all_specs_have_unique_names(self) -> None:
specs = build_specialized_subagents(tools=ALL_TOOLS)
names = [s["name"] for s in specs] # type: ignore[index]
assert len(set(names)) == len(names)
def test_extra_middleware_is_prepended_to_each_spec(self) -> None:
"""Sentinel middleware passed via ``extra_middleware`` must appear
in each subagent's ``middleware`` list, before the local rules.
This guards against the regression where specialized subagents
promised filesystem tools (``read_file``, ``ls``, ``grep``) in
their system prompts but had no filesystem middleware mounted.
"""
class _Sentinel:
pass
sentinel = _Sentinel()
specs = build_specialized_subagents(
tools=ALL_TOOLS, extra_middleware=[sentinel]
)
for spec in specs:
mws = spec["middleware"] # type: ignore[index]
assert sentinel in mws
# The sentinel must appear *before* the permission middleware
# (subagent-local rules), preserving the documented composition
# order: extra → custom → patch → dedup.
sentinel_idx = mws.index(sentinel)
perm_idx = next(
(i for i, m in enumerate(mws)
if isinstance(m, PermissionMiddleware)),
None,
)
assert perm_idx is not None
assert sentinel_idx < perm_idx
class TestFilterToolsWarningSuppression:
"""Names provided by middleware (read_file, ls, grep, …) must not
trigger the spurious "missing" warning in :func:`_filter_tools`."""
def test_middleware_provided_names_are_silent(self, caplog) -> None:
import logging
from app.agents.new_chat.subagents.config import _filter_tools
with caplog.at_level(logging.INFO, logger="app.agents.new_chat.subagents.config"):
# Allowed set asks for two registry tools (one present, one
# not) plus a bunch of middleware-provided names.
_filter_tools(
[search_surfsense_docs],
allowed_names={
"search_surfsense_docs",
"scrape_webpage", # legitimately missing → should warn
"read_file", # mw-provided → suppressed
"ls",
"grep",
"glob",
"write_todos",
},
)
warnings = [
r.message for r in caplog.records if r.levelno >= logging.INFO
]
# Exactly one warning, and it should mention scrape_webpage but not
# any middleware-provided name. Inspect the rendered "missing"
# list (between the brackets) so we don't false-match substrings
# like ``ls`` inside ``available``.
assert len(warnings) == 1, warnings
msg = warnings[0]
assert "scrape_webpage" in msg
bracket_section = msg.split("missing: ", 1)[1]
for noisy in ("read_file", "ls", "grep", "glob", "write_todos"):
assert f"'{noisy}'" not in bracket_section, msg
class TestDenyPatternsCoverage:
def test_deny_patterns_cover_canonical_write_tools(self) -> None:
canonical_writes = [
"update_memory",
"edit_file",
"write_file",
"move_file",
"mkdir",
"linear_create_issue",
"linear_update_issue",
"linear_delete_issue",
"slack_send_message",
"create_index",
"update_account",
"delete_record",
"send_email",
]
for tool_name in canonical_writes:
assert any(
_wildcard_matches(pattern, tool_name)
for pattern in WRITE_TOOL_DENY_PATTERNS
), f"no deny pattern matches {tool_name!r}"
def test_deny_patterns_do_not_match_safe_read_tools(self) -> None:
canonical_reads = [
"search_surfsense_docs",
"read_file",
"ls_tree",
"grep",
"web_search",
"scrape_webpage",
"get_connected_accounts",
"generate_report",
]
for tool_name in canonical_reads:
assert not any(
_wildcard_matches(pattern, tool_name)
for pattern in WRITE_TOOL_DENY_PATTERNS
), f"deny pattern incorrectly matches read tool {tool_name!r}"
def _wildcard_matches(pattern: str, value: str) -> bool:
"""Helper using the same matcher the rule evaluator does."""
from app.agents.new_chat.permissions import wildcard_match
return wildcard_match(value, pattern)
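
# A plausible matcher consistent with the patterns used above
# ("linear_*", "*create*", "*_create"); the shipped wildcard_match may
# differ in edge cases such as "**" versus "*" across path separators.
from fnmatch import fnmatchcase


def _wildcard_match_sketch(value: str, pattern: str) -> bool:
    return fnmatchcase(value, pattern)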

View file

@ -0,0 +1,103 @@
"""Tests for ToolCallNameRepairMiddleware."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage
from app.agents.new_chat.middleware.tool_call_repair import (
ToolCallNameRepairMiddleware,
)
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME
pytestmark = pytest.mark.unit
def _make_state(message: AIMessage) -> dict:
return {"messages": [message]}
class _FakeRuntime:
def __init__(self, context: object | None = None) -> None:
self.context = context
class TestRepair:
def test_passthrough_when_name_matches(self) -> None:
mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None
)
msg = AIMessage(content="", tool_calls=[
{"name": "echo", "args": {}, "id": "1"},
])
out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is None # no change
def test_lowercase_repair(self) -> None:
mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None
)
msg = AIMessage(content="", tool_calls=[
{"name": "Echo", "args": {"x": 1}, "id": "1"},
])
out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is not None
repaired = out["messages"][0]
assert repaired.tool_calls[0]["name"] == "echo"
def test_invalid_fallback_when_no_match(self) -> None:
mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo", INVALID_TOOL_NAME},
fuzzy_match_threshold=None,
)
msg = AIMessage(content="", tool_calls=[
{"name": "totally_different_name", "args": {"k": "v"}, "id": "1"},
])
out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is not None
repaired_call = out["messages"][0].tool_calls[0]
assert repaired_call["name"] == INVALID_TOOL_NAME
assert repaired_call["args"]["tool"] == "totally_different_name"
assert "totally_different_name" in repaired_call["args"]["error"]
def test_no_invalid_means_skip_when_unknown(self) -> None:
mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None
)
msg = AIMessage(content="", tool_calls=[
{"name": "unknown", "args": {}, "id": "1"},
])
out = mw.after_model(_make_state(msg), _FakeRuntime())
# No repair available; original returned unchanged (no update)
assert out is None
def test_fuzzy_match_works_when_enabled(self) -> None:
mw = ToolCallNameRepairMiddleware(
registered_tool_names={"search_documents"},
fuzzy_match_threshold=0.7,
)
msg = AIMessage(content="", tool_calls=[
{"name": "search_docments", "args": {}, "id": "1"},
])
out = mw.after_model(_make_state(msg), _FakeRuntime())
assert out is not None
assert out["messages"][0].tool_calls[0]["name"] == "search_documents"
def test_skips_when_no_messages(self) -> None:
mw = ToolCallNameRepairMiddleware(registered_tool_names={"echo"})
out = mw.after_model({"messages": []}, _FakeRuntime())
assert out is None
def test_runtime_context_extends_registered(self) -> None:
from types import SimpleNamespace
mw = ToolCallNameRepairMiddleware(
registered_tool_names={"echo"}, fuzzy_match_threshold=None
)
msg = AIMessage(content="", tool_calls=[
{"name": "DynamicTool", "args": {}, "id": "1"},
])
runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"]))
out = mw.after_model(_make_state(msg), runtime)
assert out is not None
assert out["messages"][0].tool_calls[0]["name"] == "dynamictool"