feat: updated agent harness

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-28 09:22:19 -07:00
parent 9ec9b64348
commit 31a372bb84
139 changed files with 12583 additions and 1111 deletions

View file

@ -0,0 +1,146 @@
"""
Integration test harness for the SurfSense agent stack.
The plan calls for an ``LLMToolEmulator``-backed harness for end-to-end
replay of ``stream_new_chat``. The currently-installed langchain version
does not expose ``LLMToolEmulator``, so this harness builds the equivalent
on top of :class:`langchain_core.language_models.fake_chat_models.FakeMessagesListChatModel`.
The harness lets a test author script a sequence of model responses
(text + optional tool calls) and replay them against the new_chat agent
graph. Tools are stubbed via ``StubToolSpec`` -> ``langchain_core.tools.tool``
decorator and execute deterministic Python callbacks.
Used by:
- ``tests/integration/agents/new_chat/test_feature_flag_smoke.py`` to
confirm the kill-switch path produces identical-shape output regardless
of which middleware flags are toggled.
- Future per-tier PRs to record golden transcripts.
"""
from __future__ import annotations
import uuid
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool, tool
class _ToolBindingFakeChatModel(FakeMessagesListChatModel):
    """Adapter so the harness model can pretend it understands ``bind_tools``.

    The base ``FakeMessagesListChatModel`` raises ``NotImplementedError`` from
    ``bind_tools``, but ``langchain.agents.create_agent`` always calls
    ``bind_tools`` to attach the tool registry. We don't actually need the
    fake to honor the tool schema -- it's already scripted to emit the right
    tool calls -- so we return self.
    """

    def bind_tools(  # type: ignore[override]
        self,
        tools: Sequence[Any],
        *,
        tool_choice: Any = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, AIMessage]:
        # Tools and tool_choice are deliberately ignored: the scripted
        # responses already contain whatever tool calls the test wants.
        return self
@dataclass
class StubToolSpec:
    """A test-mode tool: a name, description, and a deterministic body.

    ``args_schema`` may optionally declare an explicit argument schema; when
    ``None`` the ``tool`` decorator infers one from the handler signature.
    """

    name: str
    description: str
    handler: Callable[..., Any]
    args_schema: dict[str, Any] | None = None

    def build(self) -> BaseTool:
        """Realize as a `langchain_core.tools.BaseTool`."""

        # Forward args_schema so a declared schema is not silently dropped
        # (previously the field existed but was never used; ``None`` keeps
        # the decorator's default inference behavior).
        @tool(
            name_or_callable=self.name,
            description=self.description,
            args_schema=self.args_schema,
        )
        def _stub_tool(**kwargs: Any) -> Any:
            # Delegate straight to the deterministic test callback.
            return self.handler(**kwargs)

        return _stub_tool
@dataclass
class ScriptedTurn:
    """One scripted assistant turn.

    `text` is the assistant text (may be empty if pure tool call).
    `tool_calls` is a list of dicts ``{name, args, id}``; if non-empty, the
    agent will route to those tools and append a follow-up turn.
    """

    # Assistant-visible text for this turn; empty for pure tool-call turns.
    text: str = ""
    # Each entry: {"name": ..., "args": ..., "id": ...}; "args" and "id"
    # are optional and are defaulted during message construction.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
def build_scripted_messages(turns: list[ScriptedTurn]) -> list[BaseMessage]:
    """Convert :class:`ScriptedTurn` records to AIMessage payloads."""

    def _normalize(tc: dict[str, Any]) -> dict[str, Any]:
        # Fill in missing "args"/"id" so downstream consumers always see a
        # complete tool-call dict; ids default to a short random token.
        return {
            "name": tc["name"],
            "args": tc.get("args", {}),
            "id": tc.get("id") or f"call_{uuid.uuid4().hex[:8]}",
        }

    return [
        AIMessage(
            content=turn.text,
            tool_calls=[_normalize(tc) for tc in turn.tool_calls],
        )
        for turn in turns
    ]
@dataclass
class ScriptedHarness:
    """Bundle of (model, tools) ready to plug into ``create_agent``."""

    # Fake chat model preloaded with the scripted AIMessage responses.
    model: _ToolBindingFakeChatModel
    # Realized stub tools built from the provided StubToolSpec records.
    tools: list[BaseTool]
def build_scripted_harness(
    *,
    turns: list[ScriptedTurn],
    tools: list[StubToolSpec] | None = None,
    sleep: float | None = None,
) -> ScriptedHarness:
    """Construct a deterministic agent harness from a script.

    Example::

        harness = build_scripted_harness(
            turns=[
                ScriptedTurn(tool_calls=[{"name": "echo", "args": {"x": 1}}]),
                ScriptedTurn(text="done"),
            ],
            tools=[
                StubToolSpec(name="echo", description="echo args", handler=lambda **kw: kw),
            ],
        )
    """
    # Turn the script into concrete AIMessages, then load them into the
    # fake model; an absent tool list is treated as "no tools".
    scripted = build_scripted_messages(turns)
    specs = tools if tools is not None else []
    return ScriptedHarness(
        model=_ToolBindingFakeChatModel(responses=scripted, sleep=sleep),
        tools=[spec.build() for spec in specs],
    )
# Public API of the harness module (kept alphabetical).
__all__ = [
    "ScriptedHarness",
    "ScriptedTurn",
    "StubToolSpec",
    "build_scripted_harness",
    "build_scripted_messages",
]

View file

@ -0,0 +1,53 @@
"""Smoke test: scripted harness drives create_agent end-to-end and produces a tool-call-then-final-text trace."""
from __future__ import annotations
import pytest
from langchain.agents import create_agent
from tests.integration.harness import (
ScriptedTurn,
StubToolSpec,
build_scripted_harness,
)
pytestmark = pytest.mark.integration


@pytest.mark.asyncio
async def test_scripted_harness_drives_basic_agent() -> None:
    """Replay one tool-call turn plus a final text turn through create_agent."""
    # Script: first model turn calls the echo tool, second turn says "done".
    harness = build_scripted_harness(
        turns=[
            ScriptedTurn(
                tool_calls=[
                    {"name": "echo", "args": {"x": 1}, "id": "call_1"},
                ]
            ),
            ScriptedTurn(text="done"),
        ],
        tools=[
            StubToolSpec(
                name="echo",
                description="Echo args back.",
                handler=lambda **kwargs: {"echoed": kwargs},
            ),
        ],
    )
    agent = create_agent(
        harness.model,
        system_prompt="You are a test agent.",
        tools=harness.tools,
    )
    result = await agent.ainvoke({"messages": [("user", "do the thing")]})
    messages = result["messages"]
    # Class-name comparison avoids importing AIMessage/ToolMessage here.
    final_ai = next(
        (m for m in reversed(messages) if m.__class__.__name__ == "AIMessage"),
        None,
    )
    assert final_ai is not None
    assert final_ai.content == "done"
    # Exactly one tool execution, and its echoed payload reached the trace.
    tool_messages = [m for m in messages if m.__class__.__name__ == "ToolMessage"]
    assert len(tool_messages) == 1
    assert "echoed" in str(tool_messages[0].content)

View file

@ -0,0 +1 @@

View file

@ -0,0 +1 @@

View file

@ -0,0 +1 @@
"""__init__ stub so pytest discovers the prompts test module."""

View file

@ -0,0 +1,201 @@
"""Tests for the prompt fragment composer (Tier 3a)."""
from __future__ import annotations
from datetime import UTC, datetime
import pytest
from app.agents.new_chat.prompts.composer import (
ALL_TOOL_NAMES_ORDERED,
compose_system_prompt,
detect_provider_variant,
)
from app.db import ChatVisibility
pytestmark = pytest.mark.unit


@pytest.fixture
def fixed_today() -> datetime:
    # Frozen "today" so date interpolation in prompts is deterministic.
    return datetime(2025, 6, 1, 12, 0, tzinfo=UTC)
class TestProviderVariantDetection:
    """Map raw model-name strings onto the known provider prompt variants."""

    @pytest.mark.parametrize(
        "model_name,expected",
        [
            ("openai:gpt-4o-mini", "openai_classic"),
            ("openai:gpt-4-turbo", "openai_classic"),
            ("openai:gpt-5", "openai_reasoning"),
            ("openai:gpt-5-codex", "openai_reasoning"),
            ("openai:o1-preview", "openai_reasoning"),
            ("openai:o3-mini", "openai_reasoning"),
            ("anthropic:claude-3-5-sonnet", "anthropic"),
            ("anthropic/claude-opus-4", "anthropic"),
            ("google:gemini-2.0-flash", "google"),
            ("vertex:gemini-1.5-pro", "google"),
            ("groq:mixtral-8x7b", "default"),
            (None, "default"),
            ("", "default"),
        ],
    )
    def test_detection(self, model_name: str | None, expected: str) -> None:
        # None/empty and unknown providers must fall back to "default".
        assert detect_provider_variant(model_name) == expected
class TestCompose:
    """End-to-end checks on ``compose_system_prompt`` for each composition knob."""

    def test_default_prompt_has_required_blocks(self, fixed_today: datetime) -> None:
        prompt = compose_system_prompt(today=fixed_today)
        # System instruction wrapper
        assert "<system_instruction>" in prompt
        assert "</system_instruction>" in prompt
        # Date interpolated
        assert "2025-06-01" in prompt
        # Core policy blocks present
        assert "<knowledge_base_only_policy>" in prompt
        assert "<tool_routing>" in prompt
        assert "<parameter_resolution>" in prompt
        assert "<memory_protocol>" in prompt
        # Tools
        assert "<tools>" in prompt
        assert "</tools>" in prompt
        # Citations on by default
        assert "<citation_instructions>" in prompt
        assert "[citation:chunk_id]" in prompt

    def test_team_visibility_uses_team_variants(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            thread_visibility=ChatVisibility.SEARCH_SPACE,
        )
        # Team-specific phrasing in the agent block
        assert "team space" in prompt
        # Memory protocol mentions team
        # NOTE(review): "team" is a substring of "team space", so this
        # assertion is implied by the one above — kept as-is.
        assert "team" in prompt
        # Should NOT mention the user-only memory phrasing
        assert "personal knowledge base" not in prompt

    def test_private_visibility_uses_private_variants(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            thread_visibility=ChatVisibility.PRIVATE,
        )
        assert "personal knowledge base" in prompt
        # Should NOT mention the team-specific phrasing about prefixed authors
        assert "[DisplayName of the author]" not in prompt

    def test_citations_disabled_swaps_block(self, fixed_today: datetime) -> None:
        prompt_on = compose_system_prompt(today=fixed_today, citations_enabled=True)
        prompt_off = compose_system_prompt(today=fixed_today, citations_enabled=False)
        assert "Citations are DISABLED" in prompt_off
        assert "Citations are DISABLED" not in prompt_on
        assert "[citation:chunk_id]" in prompt_on

    def test_enabled_tool_filter_only_includes_listed_tools(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"web_search", "scrape_webpage"},
        )
        # NOTE(review): the "- x:" branch is redundant (it contains "x:"),
        # so each `or` collapses to its first operand.
        assert "web_search:" in prompt or "- web_search:" in prompt
        assert "scrape_webpage:" in prompt or "- scrape_webpage:" in prompt
        # Excluded tools should NOT appear in tool listing
        assert "generate_podcast:" not in prompt
        assert "generate_image:" not in prompt

    def test_disabled_tool_note_is_appended(self, fixed_today: datetime) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"web_search"},
            disabled_tool_names={"generate_image", "generate_podcast"},
        )
        # Disabled tools are surfaced to the model as display names.
        assert "DISABLED TOOLS (by user):" in prompt
        assert "Generate Image" in prompt
        assert "Generate Podcast" in prompt

    def test_mcp_routing_block_emits_when_provided(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]},
        )
        assert "<mcp_tool_routing>" in prompt
        assert "My GitLab" in prompt
        assert "gitlab_search" in prompt

    def test_mcp_routing_block_absent_when_no_servers(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={})
        assert "<mcp_tool_routing>" not in prompt

    def test_provider_block_renders_when_anthropic(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today, model_name="anthropic:claude-3-5-sonnet"
        )
        assert "<provider_hints>" in prompt
        assert "Anthropic" in prompt or "Claude" in prompt

    def test_provider_block_absent_for_default(self, fixed_today: datetime) -> None:
        prompt = compose_system_prompt(today=fixed_today, model_name="custom:foo")
        assert "<provider_hints>" not in prompt

    def test_custom_system_instructions_override_default(
        self, fixed_today: datetime
    ) -> None:
        # {resolved_today} must be interpolated inside custom instructions too.
        custom = "You are a custom assistant. Today is {resolved_today}."
        prompt = compose_system_prompt(
            today=fixed_today, custom_system_instructions=custom
        )
        assert "You are a custom assistant. Today is 2025-06-01." in prompt
        # Default block should NOT be present
        assert "<knowledge_base_only_policy>" not in prompt

    def test_use_default_false_with_no_custom_yields_no_system_block(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            use_default_system_instructions=False,
        )
        # No system_instruction wrapper but tools/citations still emitted
        assert "<system_instruction>" not in prompt
        assert "<tools>" in prompt

    def test_all_known_tools_have_fragments(self) -> None:
        # Soft assertion: verify that every tool in the canonical order
        # produces non-empty content for at least one variant.
        for tool in ALL_TOOL_NAMES_ORDERED:
            prompt = compose_system_prompt(
                today=datetime(2025, 1, 1, tzinfo=UTC),
                enabled_tool_names={tool},
            )
            assert tool in prompt, f"tool {tool!r} missing from composed prompt"
class TestStableOrderingForCacheStability:
    """Regression guard: prompt cache hit-rate depends on byte-stable prefix."""

    def test_composition_is_deterministic_given_same_inputs(
        self, fixed_today: datetime
    ) -> None:
        # Two calls differing only in set literal order must produce
        # byte-identical prompts (sets are unordered; composer must sort).
        a = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"web_search", "scrape_webpage"},
            mcp_connector_tools={"X": ["x_a", "x_b"]},
        )
        b = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"scrape_webpage", "web_search"},  # set order shouldn't matter
            mcp_connector_tools={"X": ["x_a", "x_b"]},
        )
        assert a == b

View file

@ -0,0 +1,311 @@
"""Unit tests for ActionLogMiddleware (Tier 5.2)."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
from app.agents.new_chat.tools.registry import ToolDefinition
@dataclass
class _FakeRequest:
    """Minimal stand-in for ToolCallRequest used in unit tests."""

    # The {"name", "args", "id"} payload the middleware inspects.
    tool_call: dict[str, Any]
    tool: Any = None
    state: Any = None
    runtime: Any = None
# Real langchain tool used as the subject of the middleware tests.
@tool
def make_widget(color: str, size: int) -> str:
    """Create a widget."""
    return f"made {color} {size}"
def _enabled_flags(**overrides: bool) -> AgentFeatureFlags:
    """Flags with the action log enabled; extra kwargs are passed through."""
    return AgentFeatureFlags(
        disable_new_agent_stack=False,
        enable_action_log=True,
        **overrides,
    )
def _disabled_flags() -> AgentFeatureFlags:
    """Flags with the action log explicitly turned off."""
    return AgentFeatureFlags(disable_new_agent_stack=False, enable_action_log=False)
@pytest.fixture
def patch_get_flags():
    """Return a helper that patches the middleware's get_flags to a fixed value."""

    def _patch(flags: AgentFeatureFlags):
        return patch(
            "app.agents.new_chat.middleware.action_log.get_flags",
            return_value=flags,
        )

    return _patch
@pytest.fixture
def fake_session_factory():
    """Patch ``shielded_async_session`` with a recording fake."""
    # Added rows (and a "committed" bool) are captured here for assertions.
    captured: dict[str, Any] = {"rows": []}

    class _FakeSession:
        # Async context-manager protocol: yield self, never suppress errors.
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        def add(self, row):
            captured["rows"].append(row)

        async def commit(self):
            captured["committed"] = True

    def _factory():
        return _FakeSession()

    return captured, _factory
class TestActionLogMiddlewareDisabled:
    """When logging is off (flag or missing thread), the middleware is a pure pass-through."""

    @pytest.mark.asyncio
    async def test_no_op_when_flag_off(self, patch_get_flags) -> None:
        mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {"color": "red", "size": 1}, "id": "tc1"}
        )
        handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
        with patch_get_flags(_disabled_flags()):
            result = await mw.awrap_tool_call(request, handler)
        # Wrapped handler still runs exactly once and its result is returned.
        handler.assert_awaited_once()
        assert isinstance(result, ToolMessage)

    @pytest.mark.asyncio
    async def test_no_op_when_thread_id_none(self, patch_get_flags) -> None:
        # No thread to attribute the log to -> skip persistence entirely.
        mw = ActionLogMiddleware(thread_id=None, search_space_id=1, user_id=None)
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
        )
        handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
        with patch_get_flags(_enabled_flags()):
            result = await mw.awrap_tool_call(request, handler)
        assert isinstance(result, ToolMessage)
class TestActionLogMiddlewarePersistence:
    """Row-writing behavior on success, on tool failure, and on DB failure."""

    @pytest.mark.asyncio
    async def test_writes_row_on_success(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory
        mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
        request = _FakeRequest(
            tool_call={
                "name": "make_widget",
                "args": {"color": "red", "size": 3},
                "id": "tc-abc",
            },
        )
        result_msg = ToolMessage(
            content="ok", tool_call_id="tc-abc", id="msg-1"
        )
        handler = AsyncMock(return_value=result_msg)
        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ):
            result = await mw.awrap_tool_call(request, handler)
        assert result is result_msg
        # Exactly one row, fully populated from the request + result.
        assert len(captured["rows"]) == 1
        row = captured["rows"][0]
        assert row.thread_id == 42
        assert row.search_space_id == 7
        assert row.user_id == "u1"
        assert row.tool_name == "make_widget"
        assert row.args == {"color": "red", "size": 3}
        assert row.result_id == "msg-1"
        assert row.error is None
        assert row.reverse_descriptor is None
        assert row.reversible is False

    @pytest.mark.asyncio
    async def test_writes_row_on_failure_and_reraises(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory
        mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {"color": "red"}, "id": "tc1"}
        )
        handler = AsyncMock(side_effect=ValueError("boom"))
        # The exception must propagate AND still be recorded as a row.
        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ), pytest.raises(ValueError, match="boom"):
            await mw.awrap_tool_call(request, handler)
        assert len(captured["rows"]) == 1
        row = captured["rows"][0]
        assert row.tool_name == "make_widget"
        assert row.error == {"type": "ValueError", "message": "boom"}
        assert row.result_id is None

    @pytest.mark.asyncio
    async def test_persistence_failure_does_not_break_tool_call(
        self, patch_get_flags
    ) -> None:
        """Even if the DB write blows up, the tool's result must reach the model."""
        mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
        )
        result_msg = ToolMessage(content="ok", tool_call_id="tc1")
        handler = AsyncMock(return_value=result_msg)

        def _exploding_session():
            raise RuntimeError("DB is down")

        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=_exploding_session
        ):
            result = await mw.awrap_tool_call(request, handler)
        assert result is result_msg
class TestReverseDescriptor:
    """Reverse-descriptor rendering via ToolDefinition.reverse callbacks."""

    @pytest.mark.asyncio
    async def test_renders_reverse_descriptor_when_tool_declares_one(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory

        def _reverse(args, result):
            # Inverse action: delete the widget the tool just created.
            return {"tool": "delete_widget", "args": {"id": result["id"]}}

        tool_def = ToolDefinition(
            name="make_widget",
            description="Create a widget",
            factory=lambda deps: make_widget,
            reverse=_reverse,
        )
        mw = ActionLogMiddleware(
            thread_id=1,
            search_space_id=1,
            user_id="u",
            tool_definitions={"make_widget": tool_def},
        )
        request = _FakeRequest(
            tool_call={
                "name": "make_widget",
                "args": {"color": "blue", "size": 1},
                "id": "tc-xyz",
            },
        )
        # JSON content so the middleware can parse the result for _reverse.
        result_msg = ToolMessage(
            content='{"id": "widget-9"}', tool_call_id="tc-xyz", id="msg-9"
        )
        handler = AsyncMock(return_value=result_msg)
        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ):
            await mw.awrap_tool_call(request, handler)
        row = captured["rows"][0]
        assert row.reversible is True
        assert row.reverse_descriptor == {
            "tool": "delete_widget",
            "args": {"id": "widget-9"},
        }

    @pytest.mark.asyncio
    async def test_swallows_reverse_callable_errors(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory

        def _bad_reverse(args, result):
            raise RuntimeError("reverse blew up")

        tool_def = ToolDefinition(
            name="make_widget",
            description="Create a widget",
            factory=lambda deps: make_widget,
            reverse=_bad_reverse,
        )
        mw = ActionLogMiddleware(
            thread_id=1,
            search_space_id=1,
            user_id=None,
            tool_definitions={"make_widget": tool_def},
        )
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
        )
        result_msg = ToolMessage(content="ok", tool_call_id="tc1")
        handler = AsyncMock(return_value=result_msg)
        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ):
            await mw.awrap_tool_call(request, handler)
        # A broken reverse callback degrades to "not reversible", not a crash.
        row = captured["rows"][0]
        assert row.reversible is False
        assert row.reverse_descriptor is None

    @pytest.mark.asyncio
    async def test_no_reverse_when_tool_definition_missing(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory
        mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        request = _FakeRequest(
            tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"}
        )
        handler = AsyncMock(
            return_value=ToolMessage(content="ok", tool_call_id="tc1")
        )
        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ):
            await mw.awrap_tool_call(request, handler)
        row = captured["rows"][0]
        assert row.reversible is False
class TestArgsTruncation:
    """Oversized tool args must be truncated before persistence."""

    @pytest.mark.asyncio
    async def test_huge_args_payload_is_truncated(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory
        mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        # Build a > 32KB string so the persisted payload triggers the truncation path.
        huge = "x" * (40 * 1024)
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {"blob": huge}, "id": "tc1"},
        )
        handler = AsyncMock(
            return_value=ToolMessage(content="ok", tool_call_id="tc1")
        )
        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ):
            await mw.awrap_tool_call(request, handler)
        # The stored args are replaced by a truncation marker + original size.
        row = captured["rows"][0]
        assert row.args is not None
        assert row.args.get("_truncated") is True
        assert row.args.get("_size", 0) >= 40 * 1024

View file

@ -0,0 +1,90 @@
"""Tests for BusyMutexMiddleware: per-thread lock + cancel event behavior."""
from __future__ import annotations
import pytest
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import (
BusyMutexMiddleware,
get_cancel_event,
manager,
request_cancel,
reset_cancel,
)
pytestmark = pytest.mark.unit


class _Runtime:
    """Minimal runtime stub exposing only ``config["configurable"]["thread_id"]``."""

    def __init__(self, thread_id: str | None) -> None:
        self.config = {"configurable": {"thread_id": thread_id}}
@pytest.mark.asyncio
async def test_first_acquire_succeeds_and_release_unblocks() -> None:
    """abefore_agent takes the per-thread lock; aafter_agent releases it."""
    mw = BusyMutexMiddleware()
    runtime = _Runtime("t1")
    await mw.abefore_agent({}, runtime)
    # Lock should now be held
    lock = manager.lock_for("t1")
    assert lock.locked()
    await mw.aafter_agent({}, runtime)
    assert not lock.locked()
@pytest.mark.asyncio
async def test_second_concurrent_acquire_raises_busy() -> None:
    """A second concurrent turn on the same thread fails fast with BusyError."""
    mw_a = BusyMutexMiddleware()
    mw_b = BusyMutexMiddleware()
    runtime = _Runtime("t-conflict")
    await mw_a.abefore_agent({}, runtime)
    with pytest.raises(BusyError) as excinfo:
        await mw_b.abefore_agent({}, runtime)
    # The error carries the contended thread id for client display.
    assert excinfo.value.request_id == "t-conflict"
    await mw_a.aafter_agent({}, runtime)
    # After release, mw_b can acquire
    await mw_b.abefore_agent({}, runtime)
    await mw_b.aafter_agent({}, runtime)
@pytest.mark.asyncio
async def test_cancel_event_lifecycle() -> None:
    """Cancel events start clear, set on request_cancel, reset at turn end."""
    mw = BusyMutexMiddleware()
    runtime = _Runtime("t-cancel")
    await mw.abefore_agent({}, runtime)
    event = get_cancel_event("t-cancel")
    assert not event.is_set()
    request_cancel("t-cancel")
    assert event.is_set()
    # End of turn should reset
    await mw.aafter_agent({}, runtime)
    assert not event.is_set()
@pytest.mark.asyncio
async def test_no_thread_id_raises_when_required() -> None:
    """require_thread_id=True makes a missing thread_id a hard error."""
    mw = BusyMutexMiddleware(require_thread_id=True)
    runtime = _Runtime(None)
    with pytest.raises(BusyError):
        await mw.abefore_agent({}, runtime)
@pytest.mark.asyncio
async def test_no_thread_id_skipped_when_not_required() -> None:
    """Without require_thread_id, a missing thread_id is a silent no-op."""
    mw = BusyMutexMiddleware(require_thread_id=False)
    runtime = _Runtime(None)
    await mw.abefore_agent({}, runtime)
    await mw.aafter_agent({}, runtime)
def test_reset_cancel_idempotent() -> None:
    # Should not raise even if event was never created
    reset_cancel("never-seen")

View file

@ -0,0 +1,107 @@
"""Tests for SurfSenseCompactionMiddleware: protected SystemMessage handling and content sanitization."""
from __future__ import annotations
import pytest
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from app.agents.new_chat.middleware.compaction import (
PROTECTED_SYSTEM_PREFIXES,
_is_protected_system_message,
_sanitize_message_content,
)
pytestmark = pytest.mark.unit


class TestIsProtectedSystemMessage:
    """Protection applies only to SystemMessages starting with a known prefix."""

    @pytest.mark.parametrize("prefix", PROTECTED_SYSTEM_PREFIXES)
    def test_each_prefix_protected(self, prefix: str) -> None:
        msg = SystemMessage(content=f"{prefix}\nbody\n</close>")
        assert _is_protected_system_message(msg) is True

    def test_unprotected_system_message(self) -> None:
        assert _is_protected_system_message(SystemMessage(content="random instructions")) is False

    def test_human_message_never_protected(self) -> None:
        # Message type matters, not just the content prefix.
        assert _is_protected_system_message(HumanMessage(content="<workspace_tree>...")) is False

    def test_tolerates_leading_whitespace(self) -> None:
        msg = SystemMessage(content="  \n<priority_documents>\n...")
        assert _is_protected_system_message(msg) is True
class TestSanitizeMessageContent:
    """None content (post-construction mutation) must be coerced to ''."""

    def test_returns_same_message_when_content_present(self) -> None:
        msg = AIMessage(content="hello")
        assert _sanitize_message_content(msg) is msg

    def test_replaces_none_with_empty_string(self) -> None:
        # Pydantic blocks ``content=None`` at construction; the real
        # crash happens when the streaming layer mutates ``content``
        # after-the-fact. Replicate that by force-setting on a built
        # message.
        msg = AIMessage(
            content="",
            tool_calls=[{"name": "x", "args": {}, "id": "1"}],
        )
        # Bypass pydantic validation to simulate the LiteLLM/Bedrock case
        object.__setattr__(msg, "content", None)
        sanitized = _sanitize_message_content(msg)
        assert sanitized.content == ""
class TestPartitionMessages:
    """Verify the partition override surfaces protected SystemMessages
    into ``preserved_messages`` regardless of cutoff position.
    """

    def _build_partitioner(self):
        # Construct a thin shim — we can't easily instantiate the full
        # SurfSenseCompactionMiddleware without a real model, but the
        # override path needs ``_lc_helper`` to delegate to. We mock
        # that with a simple slicing partitioner equivalent to the real one.
        from app.agents.new_chat.middleware.compaction import (
            SurfSenseCompactionMiddleware,
        )

        class _LcHelper:
            @staticmethod
            def _partition_messages(messages, cutoff):
                return messages[:cutoff], messages[cutoff:]

        class _Stub(SurfSenseCompactionMiddleware):
            # Skip the parent __init__ (needs a real model) on purpose.
            def __init__(self):
                self._lc_helper = _LcHelper()

        return _Stub()

    def test_protected_system_message_preserved_even_in_summarize_half(self) -> None:
        partitioner = self._build_partitioner()
        protected = SystemMessage(content="<priority_documents>\n...")
        msgs = [
            HumanMessage(content="old human"),
            AIMessage(content="old ai"),
            protected,
            ToolMessage(content="tool 1", tool_call_id="t1"),
            HumanMessage(content="new"),
        ]
        # Cutoff = 4 means everything before index 4 should be summarized
        to_summary, preserved = partitioner._partition_messages(msgs, 4)
        assert protected not in to_summary
        assert protected in preserved
        # The non-protected old messages remain in to_summary
        assert any(isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary)

    def test_unprotected_messages_unaffected(self) -> None:
        partitioner = self._build_partitioner()
        msgs = [HumanMessage(content="a"), HumanMessage(content="b"), HumanMessage(content="c")]
        to_summary, preserved = partitioner._partition_messages(msgs, 2)
        assert [m.content for m in to_summary] == ["a", "b"]
        assert [m.content for m in preserved] == ["c"]

View file

@ -0,0 +1,107 @@
"""Tests for SpillToBackendEdit and SpillingContextEditingMiddleware."""
from __future__ import annotations
from typing import Any
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.agents.new_chat.middleware.context_editing import (
SpillToBackendEdit,
_build_spill_placeholder,
)
pytestmark = pytest.mark.unit


def _build_history(num_pairs: int = 6) -> list[Any]:
    """Build a long history of (AIMessage with tool_call, ToolMessage) pairs."""
    msgs: list[Any] = [HumanMessage(content="please do many things")]
    for i in range(num_pairs):
        # Assistant turn requesting tool_i ...
        msgs.append(
            AIMessage(
                content="",
                tool_calls=[
                    {"name": f"tool_{i}", "args": {"i": i}, "id": f"call-{i}"},
                ],
            )
        )
        # ... paired with a large (~5KB) tool result so the approximate
        # token counter crosses spill triggers quickly.
        msgs.append(
            ToolMessage(
                content="x" * 5000,
                tool_call_id=f"call-{i}",
                name=f"tool_{i}",
                id=f"tool-msg-{i}",
            )
        )
    return msgs
def _approx_count(messages: list[Any]) -> int:
"""Trivial token counter: 1 token per 4 chars."""
total = 0
for msg in messages:
content = getattr(msg, "content", "")
if isinstance(content, str):
total += len(content) // 4
return total
class TestSpillEdit:
    """Behavioral checks for SpillToBackendEdit trigger/keep/exclude logic."""

    def test_below_trigger_does_nothing(self) -> None:
        edit = SpillToBackendEdit(trigger=1_000_000, keep=2)
        msgs = _build_history(3)
        original_lengths = [len(getattr(m, "content", "")) for m in msgs]
        edit.apply(msgs, count_tokens=_approx_count)
        new_lengths = [len(getattr(m, "content", "")) for m in msgs]
        # Nothing cleared, nothing queued for spill.
        assert original_lengths == new_lengths
        assert edit.pending_spills == []

    def test_above_trigger_clears_and_records(self) -> None:
        edit = SpillToBackendEdit(trigger=100, keep=1, path_prefix="/tool_outputs")
        msgs = _build_history(4)
        edit.apply(msgs, count_tokens=_approx_count)
        # The most-recent ToolMessage (keep=1) should remain intact
        tool_messages = [m for m in msgs if isinstance(m, ToolMessage)]
        intact = tool_messages[-1]
        assert intact.content.startswith("x")  # untouched
        # Earlier ToolMessages should now contain the placeholder text
        cleared = [
            m for m in tool_messages
            if isinstance(m.content, str) and m.content.startswith("[cleared")
        ]
        assert len(cleared) >= 1
        # And the spill list should match
        assert len(edit.pending_spills) == len(cleared)

    def test_excluded_tools_not_cleared(self) -> None:
        edit = SpillToBackendEdit(
            trigger=100,
            keep=0,
            exclude_tools=("tool_0",),
        )
        msgs = _build_history(4)
        edit.apply(msgs, count_tokens=_approx_count)
        first_tool = next(
            m for m in msgs if isinstance(m, ToolMessage) and m.name == "tool_0"
        )
        # Excluded — untouched
        assert first_tool.content.startswith("x")

    def test_drain_clears_pending(self) -> None:
        # drain_pending hands over the spill list exactly once.
        edit = SpillToBackendEdit(trigger=100, keep=1)
        msgs = _build_history(4)
        edit.apply(msgs, count_tokens=_approx_count)
        first_drain = edit.drain_pending()
        assert len(first_drain) > 0
        assert edit.drain_pending() == []

    def test_placeholder_format(self) -> None:
        path = "/tool_outputs/thread-1/tool-msg-0.txt"
        text = _build_spill_placeholder(path)
        assert path in text
        assert "explore" in text  # mentions the recovery agent

View file

@ -0,0 +1,132 @@
"""Tests for declarative dedup_key on ToolDefinition (Tier 2.3 migration)."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage
from langchain_core.tools import StructuredTool
from app.agents.new_chat.middleware.dedup_tool_calls import (
DedupHITLToolCallsMiddleware,
)
pytestmark = pytest.mark.unit


def _make_tool(name: str, *, dedup_key=None, hitl_dedup_key=None):
    """Build a StructuredTool whose metadata carries the given dedup config."""
    metadata = {}
    if dedup_key is not None:
        metadata["dedup_key"] = dedup_key
    if hitl_dedup_key is not None:
        # Legacy HITL path: a flag plus the arg name used as the dedup key.
        metadata["hitl"] = True
        metadata["hitl_dedup_key"] = hitl_dedup_key

    def _fn(**kwargs):
        return "ok"

    return StructuredTool.from_function(
        func=_fn, name=name, description="x", metadata=metadata
    )
def _msg(*calls: dict) -> AIMessage:
    """Assistant message carrying the given tool calls and no text."""
    return AIMessage(content="", tool_calls=list(calls))
class _Runtime:
    """Empty runtime stub; the middleware under test never inspects it."""

    pass
def test_callable_dedup_key_takes_priority() -> None:
    """A callable dedup_key derives the key from args; equal keys collapse."""
    tool = _make_tool(
        "create_doc",
        dedup_key=lambda args: f"{args.get('parent_id')}::{args.get('title')}",
    )
    mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
    state = {
        "messages": [
            _msg(
                {"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "1"},
                {"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "2"},
                {"name": "create_doc", "args": {"parent_id": "x", "title": "z"}, "id": "3"},
            )
        ]
    }
    out = mw.after_model(state, _Runtime())
    assert out is not None
    new_calls = out["messages"][0].tool_calls
    assert len(new_calls) == 2  # one duplicate dropped
    # First occurrence of a key wins; the unique "z" call survives.
    assert {c["id"] for c in new_calls} == {"1", "3"}
def test_string_hitl_dedup_key_still_works() -> None:
    """A string hitl_dedup_key names the arg compared (case-insensitively)."""
    tool = _make_tool("send_x", hitl_dedup_key="subject")
    mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
    state = {
        "messages": [
            _msg(
                {"name": "send_x", "args": {"subject": "Hello"}, "id": "1"},
                {"name": "send_x", "args": {"subject": "hello"}, "id": "2"},  # case
            )
        ]
    }
    out = mw.after_model(state, _Runtime())
    assert out is not None
    assert len(out["messages"][0].tool_calls) == 1
def test_no_agent_tools_means_no_dedup() -> None:
"""After the cleanup tier removed the legacy ``_NATIVE_HITL_TOOL_DEDUP_KEYS``
map, dedup is purely declarative no resolvers means no dedup runs.
Coverage for the previously hardcoded native HITL tools now lives on
each :class:`ToolDefinition.dedup_key` in
:mod:`app.agents.new_chat.tools.registry`, which is wired through to
``tool.metadata`` by :func:`build_tools`.
"""
mw = DedupHITLToolCallsMiddleware(agent_tools=None)
state = {
"messages": [
_msg(
{"name": "create_notion_page", "args": {"title": "X"}, "id": "1"},
{"name": "create_notion_page", "args": {"title": "x"}, "id": "2"},
)
]
}
out = mw.after_model(state, _Runtime())
assert out is None
def test_registry_propagates_dedup_key_to_tool_metadata() -> None:
"""Smoke-check the wiring path that replaced the legacy native map.
``ToolDefinition.dedup_key`` set in the registry must be copied onto
the constructed tool's ``metadata`` so :class:`DedupHITLToolCallsMiddleware`
can pick it up at agent build time.
"""
from app.agents.new_chat.tools.registry import (
BUILTIN_TOOLS,
wrap_dedup_key_by_arg_name,
)
notion_tool_defs = [t for t in BUILTIN_TOOLS if t.name == "create_notion_page"]
assert notion_tool_defs, "registry should still expose create_notion_page"
tool_def = notion_tool_defs[0]
assert tool_def.dedup_key is not None
# Same wrapping helper used in the registry — sanity check identity
sample = wrap_dedup_key_by_arg_name("title")({"title": "Plan"})
assert sample == "plan"
def test_unknown_tool_passes_through() -> None:
mw = DedupHITLToolCallsMiddleware(agent_tools=None)
state = {
"messages": [
_msg(
{"name": "anything_else", "args": {"x": 1}, "id": "1"},
{"name": "anything_else", "args": {"x": 1}, "id": "2"},
)
]
}
out = mw.after_model(state, _Runtime())
assert out is None # no dedup configured -> kept

View file

@ -0,0 +1,128 @@
"""Lock in the default-allow layering used by ``chat_deepagent``.
The agent factory wires ``PermissionMiddleware`` with three rulesets,
earliest -> latest:
1. ``surfsense_defaults`` (single ``allow */*`` rule)
2. ``connector_synthesized`` (deny rules for tools whose required
connector is missing)
3. (future) user-defined rules from the Agent Permissions UI
Without #1 every read-only built-in (``ls``, ``read_file``, ``grep``,
``glob``, ``web_search``, …) defaulted to ``ask`` because
``permissions.evaluate`` returns ``ask`` when no rule matches. That
caused two production-painful behaviors:
* Resume payloads with a prior reject decision bled into innocent
read-only tool calls, raising ``RejectedError("ls")``.
* Mutating connector tools got *double* prompted once via the
middleware ``ask`` and again via the per-tool ``interrupt()`` in
``app.agents.new_chat.tools.hitl``.
These tests pin the layering so a refactor that drops the default
ruleset fails loud.
"""
from __future__ import annotations
import pytest
from app.agents.new_chat.permissions import (
Rule,
Ruleset,
aggregate_action,
evaluate_many,
)
pytestmark = pytest.mark.unit
def _layered_rulesets(connector_denies: list[Rule]) -> list[Ruleset]:
    """Mirror the two-layer ruleset stack that ``chat_deepagent`` wires up."""
    default_allow = Ruleset(
        rules=[Rule(permission="*", pattern="*", action="allow")],
        origin="surfsense_defaults",
    )
    connector_layer = Ruleset(
        rules=connector_denies,
        origin="connector_synthesized",
    )
    return [default_allow, connector_layer]
class TestReadOnlyToolsAllowed:
    """Read-only built-ins must NOT default to ask."""

    @pytest.mark.parametrize(
        "tool_name",
        [
            "ls",
            "read_file",
            "grep",
            "glob",
            "web_search",
            "scrape_webpage",
            "search_surfsense_docs",
            "get_connected_accounts",
            "write_todos",
            "task",
            "_noop",
            "invalid",
            "update_memory",
        ],
    )
    def test_default_allow_covers_safe_builtin(self, tool_name: str) -> None:
        # With no connector denies, layer #1's blanket allow must match every
        # tool name — including internal ones like ``_noop``.
        rulesets = _layered_rulesets(connector_denies=[])
        rules = evaluate_many(tool_name, [tool_name], *rulesets)
        assert aggregate_action(rules) == "allow"


class TestConnectorDenyOverridesDefaultAllow:
    """Connector-synthesized denies must beat the default-allow rule."""

    def test_missing_connector_tool_is_denied(self) -> None:
        # Layer #2 (connector denies) is layered after layer #1, so it wins.
        rulesets = _layered_rulesets(
            connector_denies=[
                Rule(permission="linear_create_issue", pattern="*", action="deny")
            ]
        )
        rules = evaluate_many(
            "linear_create_issue", ["linear_create_issue"], *rulesets
        )
        assert aggregate_action(rules) == "deny"

    def test_default_allow_still_applies_to_other_tools(self) -> None:
        """A deny rule for one tool must not bleed onto unrelated calls."""
        rulesets = _layered_rulesets(
            connector_denies=[
                Rule(permission="linear_create_issue", pattern="*", action="deny")
            ]
        )
        rules = evaluate_many("ls", ["ls"], *rulesets)
        assert aggregate_action(rules) == "allow"


class TestUserRuleOverridesDefault:
    """User rules layered last must override the default-allow rule."""

    def test_user_ask_overrides_default_allow(self) -> None:
        defaults = Ruleset(
            rules=[Rule(permission="*", pattern="*", action="allow")],
            origin="surfsense_defaults",
        )
        user_ruleset = Ruleset(
            rules=[Rule(permission="ls", pattern="*", action="ask")],
            origin="user",
        )
        rules = evaluate_many("ls", ["ls"], defaults, user_ruleset)
        assert aggregate_action(rules) == "ask"

    def test_user_deny_overrides_default_allow(self) -> None:
        defaults = Ruleset(
            rules=[Rule(permission="*", pattern="*", action="allow")],
            origin="surfsense_defaults",
        )
        user_ruleset = Ruleset(
            # ``send_*`` also exercises permission-name globbing in the user layer.
            rules=[Rule(permission="send_*", pattern="*", action="deny")],
            origin="user",
        )
        rules = evaluate_many("send_gmail_email", ["send_gmail_email"], defaults, user_ruleset)
        assert aggregate_action(rules) == "deny"

View file

@ -0,0 +1,99 @@
"""Tests for DoomLoopMiddleware signature equality detection."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage
from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware, _signature
pytestmark = pytest.mark.unit
def test_signature_is_stable_for_identical_args() -> None:
    """Key order inside the args dict must not affect the signature."""
    first = _signature("search", {"q": "hello", "n": 10})
    second = _signature("search", {"n": 10, "q": "hello"})
    assert first == second


def test_signature_changes_with_args() -> None:
    """Different argument values must produce different signatures."""
    assert _signature("search", {"q": "hello"}) != _signature("search", {"q": "world"})


def test_signature_changes_with_name() -> None:
    """Different tool names must produce different signatures."""
    assert _signature("search", {"q": "x"}) != _signature("read", {"q": "x"})
class _FakeRuntime:
def __init__(self, thread_id: str | None = "thread-1") -> None:
self.config = {"configurable": {"thread_id": thread_id}}
def _msg_calling(name: str, args: dict, call_id: str) -> AIMessage:
return AIMessage(
content="",
tool_calls=[{"name": name, "args": args, "id": call_id}],
)
def test_threshold_triggers_after_n_identical_calls() -> None:
    """The Nth identical (name, args) call in one thread must interrupt the run."""
    mw = DoomLoopMiddleware(threshold=3)
    runtime = _FakeRuntime()
    # First two calls — under threshold
    for i in range(2):
        out = mw.after_model(
            {"messages": [_msg_calling("repeat", {"x": 1}, f"call-{i}")]},
            runtime,
        )
        assert out is None
    # Third identical call should trigger ``langgraph.types.interrupt``.
    # In a unit-test context (no runnable graph), ``interrupt`` raises
    # ``RuntimeError`` because ``get_config`` has nothing to bind to —
    # we accept that as proof the interrupt path was taken (the
    # alternative would be no exception, which would mean the loop
    # detection never fired).
    with pytest.raises(Exception) as excinfo:
        mw.after_model(
            {"messages": [_msg_calling("repeat", {"x": 1}, "call-3")]},
            runtime,
        )
    name = type(excinfo.value).__name__.lower()
    assert (
        "interrupt" in name
        or "runtimeerror" in name
    ), f"Expected an interrupt-style exception, got {name}"


def test_does_not_trigger_when_args_differ() -> None:
    """The same tool name with different args must not count as a loop."""
    mw = DoomLoopMiddleware(threshold=2)
    runtime = _FakeRuntime()
    out = mw.after_model(
        {"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime
    )
    assert out is None
    out = mw.after_model(
        {"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime
    )
    assert out is None


def test_separate_threads_have_independent_windows() -> None:
    """Loop counting is keyed by thread_id; threads must not share a window."""
    mw = DoomLoopMiddleware(threshold=2)
    rt_a = _FakeRuntime(thread_id="A")
    rt_b = _FakeRuntime(thread_id="B")
    mw.after_model({"messages": [_msg_calling("foo", {}, "1")]}, rt_a)
    # thread B should NOT count thread A's call
    out = mw.after_model({"messages": [_msg_calling("foo", {}, "1")]}, rt_b)
    assert out is None  # not yet at threshold for B


def test_invalid_threshold_rejected() -> None:
    """``threshold=1`` must be rejected eagerly, at construction time."""
    with pytest.raises(ValueError):
        DoomLoopMiddleware(threshold=1)

View file

@ -0,0 +1,120 @@
"""Tests for the agent feature-flag system."""
from __future__ import annotations
import pytest
from app.agents.new_chat.feature_flags import (
AgentFeatureFlags,
reload_for_tests,
)
pytestmark = pytest.mark.unit
def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
for name in [
"SURFSENSE_DISABLE_NEW_AGENT_STACK",
"SURFSENSE_ENABLE_CONTEXT_EDITING",
"SURFSENSE_ENABLE_COMPACTION_V2",
"SURFSENSE_ENABLE_RETRY_AFTER",
"SURFSENSE_ENABLE_MODEL_FALLBACK",
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT",
"SURFSENSE_ENABLE_TOOL_CALL_LIMIT",
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR",
"SURFSENSE_ENABLE_DOOM_LOOP",
"SURFSENSE_ENABLE_PERMISSION",
"SURFSENSE_ENABLE_BUSY_MUTEX",
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
"SURFSENSE_ENABLE_SKILLS",
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"SURFSENSE_ENABLE_ACTION_LOG",
"SURFSENSE_ENABLE_REVERT_ROUTE",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
"SURFSENSE_ENABLE_OTEL",
]:
monkeypatch.delenv(name, raising=False)
def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no env vars set, every flag is off and the kill switch is idle."""
    _clear_all(monkeypatch)
    flags = reload_for_tests()
    assert isinstance(flags, AgentFeatureFlags)
    assert flags.disable_new_agent_stack is False
    assert flags.any_new_middleware_enabled() is False


def test_master_kill_switch_overrides_individual_flags(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """SURFSENSE_DISABLE_NEW_AGENT_STACK must force every individual flag off."""
    _clear_all(monkeypatch)
    monkeypatch.setenv("SURFSENSE_DISABLE_NEW_AGENT_STACK", "true")
    monkeypatch.setenv("SURFSENSE_ENABLE_CONTEXT_EDITING", "true")
    monkeypatch.setenv("SURFSENSE_ENABLE_PERMISSION", "true")
    flags = reload_for_tests()
    assert flags.disable_new_agent_stack is True
    # Individually-enabled flags read as off while the kill switch is on.
    assert flags.enable_context_editing is False
    assert flags.enable_permission is False
    assert flags.any_new_middleware_enabled() is False


@pytest.mark.parametrize("truthy", ["1", "true", "TRUE", "yes", "on"])
def test_individual_flags_truthy_values(
    monkeypatch: pytest.MonkeyPatch, truthy: str
) -> None:
    # All conventional truthy spellings (case-insensitive) flip a flag on.
    _clear_all(monkeypatch)
    monkeypatch.setenv("SURFSENSE_ENABLE_RETRY_AFTER", truthy)
    flags = reload_for_tests()
    assert flags.enable_retry_after is True
    assert flags.any_new_middleware_enabled() is True


@pytest.mark.parametrize("falsy", ["0", "false", "no", "off", "", "garbage"])
def test_individual_flags_falsy_values(
    monkeypatch: pytest.MonkeyPatch, falsy: str
) -> None:
    # Unrecognized strings (e.g. "garbage") must parse as off, not raise.
    _clear_all(monkeypatch)
    monkeypatch.setenv("SURFSENSE_ENABLE_RETRY_AFTER", falsy)
    flags = reload_for_tests()
    assert flags.enable_retry_after is False


def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) -> None:
    """Each attribute maps to exactly one env var and flips on in isolation."""
    _clear_all(monkeypatch)
    flag_to_env = {
        "enable_context_editing": "SURFSENSE_ENABLE_CONTEXT_EDITING",
        "enable_compaction_v2": "SURFSENSE_ENABLE_COMPACTION_V2",
        "enable_retry_after": "SURFSENSE_ENABLE_RETRY_AFTER",
        "enable_model_fallback": "SURFSENSE_ENABLE_MODEL_FALLBACK",
        "enable_model_call_limit": "SURFSENSE_ENABLE_MODEL_CALL_LIMIT",
        "enable_tool_call_limit": "SURFSENSE_ENABLE_TOOL_CALL_LIMIT",
        "enable_tool_call_repair": "SURFSENSE_ENABLE_TOOL_CALL_REPAIR",
        "enable_doom_loop": "SURFSENSE_ENABLE_DOOM_LOOP",
        "enable_permission": "SURFSENSE_ENABLE_PERMISSION",
        "enable_busy_mutex": "SURFSENSE_ENABLE_BUSY_MUTEX",
        "enable_llm_tool_selector": "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
        "enable_skills": "SURFSENSE_ENABLE_SKILLS",
        "enable_specialized_subagents": "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
        "enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
        "enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
        "enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
        "enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
        "enable_otel": "SURFSENSE_ENABLE_OTEL",
    }
    # `enable_otel` is intentionally orthogonal — it does NOT count toward
    # ``any_new_middleware_enabled`` because OTel is observability-only and
    # ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
    counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}
    for attr, env_name in flag_to_env.items():
        _clear_all(monkeypatch)
        monkeypatch.setenv(env_name, "true")
        flags = reload_for_tests()
        assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}"
        if attr in counts_toward_middleware:
            assert flags.any_new_middleware_enabled() is True
        else:
            assert flags.any_new_middleware_enabled() is False

View file

@ -0,0 +1,119 @@
"""Tests for NoopInjectionMiddleware provider-compat logic."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.new_chat.middleware.noop_injection import (
NOOP_TOOL_NAME,
NoopInjectionMiddleware,
_last_ai_has_tool_calls,
_provider_needs_noop,
)
pytestmark = pytest.mark.unit
class _LiteLLMModel:
def _get_ls_params(self):
return {"ls_provider": "litellm"}
class _BedrockModel:
def _get_ls_params(self):
return {"ls_provider": "bedrock"}
class _OpenAIModel:
def _get_ls_params(self):
return {"ls_provider": "openai"}
class _ChatLiteLLM: # name-only fallback
pass
class TestProviderDetection:
    """``_provider_needs_noop`` decides via ls_provider, falling back to class name."""

    def test_litellm(self) -> None:
        assert _provider_needs_noop(_LiteLLMModel()) is True

    def test_bedrock(self) -> None:
        assert _provider_needs_noop(_BedrockModel()) is True

    def test_openai_does_not_need(self) -> None:
        assert _provider_needs_noop(_OpenAIModel()) is False

    def test_class_name_fallback(self) -> None:
        # ``_ChatLiteLLM`` exposes no ``_get_ls_params``; detection must fall
        # back to matching on the class name alone.
        assert _provider_needs_noop(_ChatLiteLLM()) is True


class TestHistoryDetection:
    """``_last_ai_has_tool_calls`` looks at the most recent AIMessage only."""

    def test_last_ai_has_tool_calls(self) -> None:
        msgs = [
            HumanMessage(content="hi"),
            AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]),
        ]
        assert _last_ai_has_tool_calls(msgs) is True

    def test_last_ai_no_tool_calls(self) -> None:
        msgs = [
            HumanMessage(content="hi"),
            AIMessage(content="hello"),
        ]
        assert _last_ai_has_tool_calls(msgs) is False

    def test_no_ai_in_history(self) -> None:
        # No AIMessage at all -> nothing to inspect -> False.
        assert _last_ai_has_tool_calls([HumanMessage(content="hi")]) is False
class _FakeRequest:
def __init__(self, *, tools, messages, model) -> None:
self.tools = tools
self.messages = messages
self.model = model
def override(self, *, tools):
return _FakeRequest(tools=tools, messages=self.messages, model=self.model)
class TestShouldInject:
    """All three conditions must hold: empty tools, tool-calls in history, needy provider."""

    def test_injects_when_all_conditions_met(self) -> None:
        mw = NoopInjectionMiddleware()
        msgs = [
            HumanMessage(content="hi"),
            AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]),
        ]
        req = _FakeRequest(tools=[], messages=msgs, model=_LiteLLMModel())
        assert mw._should_inject(req) is True

    def test_skips_when_tools_present(self) -> None:
        # A non-empty tool list means the request already has a tool to offer.
        mw = NoopInjectionMiddleware()
        req = _FakeRequest(
            tools=[object()],
            messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])],
            model=_LiteLLMModel(),
        )
        assert mw._should_inject(req) is False

    def test_skips_when_no_history_tool_calls(self) -> None:
        mw = NoopInjectionMiddleware()
        req = _FakeRequest(
            tools=[],
            messages=[HumanMessage(content="hi")],
            model=_LiteLLMModel(),
        )
        assert mw._should_inject(req) is False

    def test_skips_for_openai(self) -> None:
        # Providers outside the compat list never get the noop tool.
        mw = NoopInjectionMiddleware()
        req = _FakeRequest(
            tools=[],
            messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])],
            model=_OpenAIModel(),
        )
        assert mw._should_inject(req) is False


def test_noop_tool_name_is_underscore_noop() -> None:
    # The name is part of the provider-facing contract; pin it.
    assert NOOP_TOOL_NAME == "_noop"

View file

@ -0,0 +1,195 @@
"""Tests for the OtelSpanMiddleware adapter (Tier 3b)."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from app.agents.new_chat.middleware.otel_span import (
OtelSpanMiddleware,
_annotate_model_response,
_annotate_tool_result,
_resolve_input_size,
_resolve_model_attrs,
_resolve_tool_name,
)
pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _disable_otel(monkeypatch: pytest.MonkeyPatch):
    """Keep OTel disabled for every test in this module, restoring state after."""
    monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
    monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true")
    from app.observability import otel as ot

    ot.reload_for_tests()
    yield
    # Re-read env on teardown so later tests see a fresh module singleton.
    ot.reload_for_tests()


class TestResolveModelAttrs:
    """Model name/provider extraction walks a chain of candidate attributes."""

    def test_extracts_model_name_and_provider(self) -> None:
        request = MagicMock()
        # spec= restricts the mock's attributes so the fallback chain is real.
        request.model = MagicMock(spec=["model_name", "provider"])
        request.model.model_name = "gpt-4o-mini"
        request.model.provider = "openai"
        assert _resolve_model_attrs(request) == ("gpt-4o-mini", "openai")

    def test_handles_missing_model(self) -> None:
        request = MagicMock()
        request.model = None
        assert _resolve_model_attrs(request) == (None, None)

    def test_falls_back_through_attribute_chain(self) -> None:
        request = MagicMock()
        request.model = MagicMock(spec=["model_id", "_llm_type"])
        request.model.model_id = "claude-3-5-sonnet"
        request.model._llm_type = "anthropic-chat"
        model_id, provider = _resolve_model_attrs(request)
        assert model_id == "claude-3-5-sonnet"
        assert provider == "anthropic-chat"


class TestResolveToolName:
    def test_prefers_request_tool_name(self) -> None:
        request = MagicMock()
        # NOTE: MagicMock(name=...) names the mock itself; the explicit
        # ``.name`` assignment below is what sets the attribute under test.
        request.tool = MagicMock(name="ToolStub")
        request.tool.name = "scrape_webpage"
        assert _resolve_tool_name(request) == "scrape_webpage"

    def test_falls_back_to_tool_call_name(self) -> None:
        request = MagicMock()
        request.tool = None
        request.tool_call = {"name": "web_search", "args": {}}
        assert _resolve_tool_name(request) == "web_search"

    def test_unknown_when_nothing_resolves(self) -> None:
        request = MagicMock()
        request.tool = None
        request.tool_call = {}
        assert _resolve_tool_name(request) == "unknown"


class TestResolveInputSize:
    def test_returns_repr_length_of_args(self) -> None:
        request = MagicMock()
        request.tool_call = {"args": {"query": "hello world"}}
        size = _resolve_input_size(request)
        assert isinstance(size, int)
        assert size > 0

    def test_handles_no_tool_call(self) -> None:
        # No tool_call at all -> no measurable input.
        request = MagicMock()
        request.tool_call = None
        assert _resolve_input_size(request) is None
class TestAnnotateModelResponse:
    """Token usage from ``usage_metadata`` is copied onto the span."""

    def test_attaches_token_counts_when_present(self) -> None:
        sp = MagicMock()
        msg = AIMessage(
            content="hello",
            usage_metadata={
                "input_tokens": 100,
                "output_tokens": 50,
                "total_tokens": 150,
            },
        )
        _annotate_model_response(sp, msg)
        sp.set_attribute.assert_any_call("tokens.prompt", 100)
        sp.set_attribute.assert_any_call("tokens.completion", 50)
        sp.set_attribute.assert_any_call("tokens.total", 150)

    def test_handles_response_with_no_metadata(self) -> None:
        sp = MagicMock()
        msg = AIMessage(content="hello")
        # Should not raise even when usage_metadata is missing
        _annotate_model_response(sp, msg)


class TestAnnotateToolResult:
    """Tool output size, status, and error markers are recorded on the span."""

    def test_records_size_and_status(self) -> None:
        sp = MagicMock()
        result = ToolMessage(
            content="result text",
            tool_call_id="abc",
            status="success",
        )
        _annotate_tool_result(sp, result)
        sp.set_attribute.assert_any_call("tool.output.size", len("result text"))
        sp.set_attribute.assert_any_call("tool.status", "success")

    def test_marks_errors(self) -> None:
        sp = MagicMock()
        result = ToolMessage(
            content="oops",
            tool_call_id="abc",
            # The error marker travels in additional_kwargs, not in status.
            additional_kwargs={"error": {"code": "x"}},
        )
        _annotate_tool_result(sp, result)
        sp.set_attribute.assert_any_call("tool.error", True)
@pytest.mark.asyncio
class TestMiddlewareIntegration:
    """The middleware must be a transparent pass-through wrapper in both modes."""

    async def test_awrap_model_call_passes_through_when_disabled(self) -> None:
        mw = OtelSpanMiddleware()
        called: dict[str, Any] = {}

        async def handler(req):
            called["req"] = req
            return AIMessage(content="ok")

        request = MagicMock()
        result = await mw.awrap_model_call(request, handler)
        # The exact request object must reach the handler unmodified.
        assert called["req"] is request
        assert isinstance(result, AIMessage)
        assert result.content == "ok"

    async def test_awrap_tool_call_passes_through_when_disabled(self) -> None:
        mw = OtelSpanMiddleware()

        async def handler(req):
            return ToolMessage(content="result", tool_call_id="abc")

        request = MagicMock()
        result = await mw.awrap_tool_call(request, handler)
        assert isinstance(result, ToolMessage)
        assert result.content == "result"

    async def test_awrap_model_call_propagates_exceptions(self) -> None:
        mw = OtelSpanMiddleware()

        async def handler(req):
            raise ValueError("boom")

        # Handler exceptions must surface, not be swallowed by span handling.
        with pytest.raises(ValueError):
            await mw.awrap_model_call(MagicMock(), handler)

    async def test_with_otel_enabled_does_not_alter_result(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
        monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
        from app.observability import otel as ot

        ot.reload_for_tests()
        try:
            mw = OtelSpanMiddleware()

            async def handler(req):
                return AIMessage(content="enabled")

            request = MagicMock()
            request.model = MagicMock()
            request.model.model_name = "gpt-4o"
            request.model.provider = "openai"
            result = await mw.awrap_model_call(request, handler)
            assert isinstance(result, AIMessage)
            assert result.content == "enabled"
        finally:
            # Always restore the module singleton for subsequent tests.
            ot.reload_for_tests()

View file

@ -0,0 +1,116 @@
"""Tests for PermissionMiddleware end-to-end behavior."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from app.agents.new_chat.errors import CorrectedError, RejectedError
from app.agents.new_chat.middleware.permission import PermissionMiddleware
from app.agents.new_chat.permissions import Rule, Ruleset
pytestmark = pytest.mark.unit
class _FakeRuntime:
config: dict = {"configurable": {"thread_id": "test"}}
def _msg(*tool_calls: dict) -> AIMessage:
return AIMessage(content="", tool_calls=list(tool_calls))
class TestAllow:
    def test_passthrough_when_allow(self) -> None:
        """An allow ruling leaves state untouched — middleware returns None."""
        rs = Ruleset(rules=[Rule("send_email", "*", "allow")])
        mw = PermissionMiddleware(rulesets=[rs])
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        out = mw.after_model(state, _FakeRuntime())
        assert out is None  # no change


class TestDeny:
    def test_replaces_with_deny_tool_message(self) -> None:
        """A denied call is stripped and answered with an error ToolMessage."""
        rs = Ruleset(rules=[Rule("send_email", "*", "deny")])
        mw = PermissionMiddleware(rulesets=[rs])
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        out = mw.after_model(state, _FakeRuntime())
        assert out is not None
        msgs = out["messages"]
        # Find the deny ToolMessage
        deny_msgs = [m for m in msgs if isinstance(m, ToolMessage)]
        assert len(deny_msgs) == 1
        assert deny_msgs[0].status == "error"
        assert "permission_denied" in str(deny_msgs[0].additional_kwargs)
        # AIMessage's tool_calls should now be empty (denied call removed)
        ai_msg = next(m for m in msgs if isinstance(m, AIMessage))
        assert ai_msg.tool_calls == []

    def test_mixed_allow_deny(self) -> None:
        """Only the denied call is removed; allowed calls survive in place."""
        rs = Ruleset(
            rules=[
                Rule("send_email", "*", "deny"),
                Rule("read", "*", "allow"),
            ]
        )
        mw = PermissionMiddleware(rulesets=[rs])
        state = {
            "messages": [
                _msg(
                    {"name": "send_email", "args": {}, "id": "1"},
                    {"name": "read", "args": {}, "id": "2"},
                )
            ]
        }
        out = mw.after_model(state, _FakeRuntime())
        assert out is not None
        ai_msg = next(m for m in out["messages"] if isinstance(m, AIMessage))
        assert len(ai_msg.tool_calls) == 1
        assert ai_msg.tool_calls[0]["name"] == "read"


class TestAsk:
    """No matching rule -> ask; the interrupt decision drives the outcome."""

    def test_reject_without_feedback_raises(self) -> None:
        # Default: nothing matches -> ask
        rs = Ruleset(rules=[])
        mw = PermissionMiddleware(rulesets=[rs])
        # Bypass real interrupt — patch the helper
        mw._raise_interrupt = lambda **kw: {"decision_type": "reject"}  # type: ignore[assignment]
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        with pytest.raises(RejectedError):
            mw.after_model(state, _FakeRuntime())

    def test_reject_with_feedback_raises_corrected(self) -> None:
        rs = Ruleset(rules=[])
        mw = PermissionMiddleware(rulesets=[rs])
        mw._raise_interrupt = lambda **kw: {  # type: ignore[assignment]
            "decision_type": "reject",
            "feedback": "use a different subject line",
        }
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        with pytest.raises(CorrectedError) as excinfo:
            mw.after_model(state, _FakeRuntime())
        # The user's feedback must ride along on the exception.
        assert excinfo.value.feedback == "use a different subject line"

    def test_once_proceeds_without_persisting(self) -> None:
        mw = PermissionMiddleware(rulesets=[])
        mw._raise_interrupt = lambda **kw: {"decision_type": "once"}  # type: ignore[assignment]
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        out = mw.after_model(state, _FakeRuntime())
        # No state change because all calls kept
        assert out is None
        # No new rule persisted
        assert mw._runtime_ruleset.rules == []

    def test_always_persists_runtime_rule(self) -> None:
        mw = PermissionMiddleware(rulesets=[])
        mw._raise_interrupt = lambda **kw: {"decision_type": "always"}  # type: ignore[assignment]
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        out = mw.after_model(state, _FakeRuntime())
        assert out is None  # call kept
        # Runtime ruleset got the always-allow rule
        new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
        assert any(
            r.permission == "send_email" for r in new_rules
        )

View file

@ -0,0 +1,111 @@
"""Tests for the wildcard matcher and rule evaluator (opencode evaluate.ts parity)."""
from __future__ import annotations
import pytest
from app.agents.new_chat.permissions import (
Rule,
Ruleset,
aggregate_action,
evaluate,
evaluate_many,
wildcard_match,
)
pytestmark = pytest.mark.unit
class TestWildcardMatch:
    """Glob semantics: '*' stops at '/', '**' crosses it; ':' is an ordinary char."""

    @pytest.mark.parametrize(
        "value,pattern,expected",
        [
            ("edit", "edit", True),
            ("edit", "*", True),
            ("read", "edit", False),
            ("/documents/secrets/x", "/documents/secrets/**", True),
            # Single-segment glob: '*' does not cross '/'
            ("/documents/secrets/x", "/documents/*/x", True),
            ("/documents/foo/bar/x", "/documents/*/x", False),
            ("/documents/foo/x", "/documents/*/x", True),
            ("linear_create", "linear_*", True),
            ("notion_create", "linear_*", False),
            # ':' is not a separator, so '*' matches it
            ("mcp:notion:create_page", "mcp:*", True),
            ("mcp:notion:create_page", "mcp:**", True),
            # But '/' IS a separator
            ("foo/bar", "foo/*", True),
            ("foo/bar/baz", "foo/*", False),
        ],
    )
    def test_match(self, value: str, pattern: str, expected: bool) -> None:
        assert wildcard_match(value, pattern) is expected


class TestEvaluate:
    def test_default_action_is_ask(self) -> None:
        """With no rulesets at all, evaluate falls back to an 'ask' rule."""
        rule = evaluate("edit", "/foo/bar")
        assert rule.action == "ask"
        assert rule.permission == "edit"

    def test_last_match_wins(self) -> None:
        rs = Ruleset(
            rules=[
                Rule("edit", "*", "allow"),
                Rule("edit", "/secrets/**", "deny"),
            ]
        )
        # Second rule (deny) is more specific AND specified later
        assert evaluate("edit", "/secrets/x", rs).action == "deny"
        # First rule (allow) covers the rest
        assert evaluate("edit", "/public/x", rs).action == "allow"

    def test_layered_rulesets_later_overrides_earlier(self) -> None:
        defaults = Ruleset(rules=[Rule("edit", "*", "ask")], origin="defaults")
        space = Ruleset(rules=[Rule("edit", "*", "allow")], origin="space")
        thread = Ruleset(rules=[Rule("edit", "*", "deny")], origin="thread")
        # All three layered: thread wins
        assert evaluate("edit", "x", defaults, space, thread).action == "deny"
        # Without thread: space wins
        assert evaluate("edit", "x", defaults, space).action == "allow"

    def test_permission_wildcard(self) -> None:
        # Globbing applies to the permission name, not just the pattern.
        rs = Ruleset(rules=[Rule("linear_*", "*", "allow")])
        assert evaluate("linear_create_issue", "x", rs).action == "allow"
        assert evaluate("notion_create", "x", rs).action == "ask"

    def test_pattern_wildcard(self) -> None:
        rs = Ruleset(rules=[Rule("edit", "/documents/secrets/**", "deny")])
        assert evaluate("edit", "/documents/secrets/foo", rs).action == "deny"
        assert evaluate("edit", "/documents/public/foo", rs).action == "ask"

    def test_evaluate_many(self) -> None:
        """Batch evaluation keeps per-pattern results in input order."""
        rs = Ruleset(
            rules=[
                Rule("edit", "*", "allow"),
                Rule("edit", "/secrets/*", "deny"),
            ]
        )
        results = evaluate_many("edit", ["/public/x", "/secrets/y"], rs)
        assert [r.action for r in results] == ["allow", "deny"]


class TestAggregateAction:
    """Aggregation severity order: deny > ask > allow; empty input means ask."""

    def test_any_deny_means_deny(self) -> None:
        rules = [
            Rule("a", "*", "allow"),
            Rule("a", "*", "deny"),
            Rule("a", "*", "ask"),
        ]
        assert aggregate_action(rules) == "deny"

    def test_any_ask_means_ask_when_no_deny(self) -> None:
        rules = [Rule("a", "*", "allow"), Rule("a", "*", "ask")]
        assert aggregate_action(rules) == "ask"

    def test_all_allow_means_allow(self) -> None:
        rules = [Rule("a", "*", "allow"), Rule("a", "*", "allow")]
        assert aggregate_action(rules) == "allow"

    def test_empty_means_ask(self) -> None:
        assert aggregate_action([]) == "ask"

View file

@ -0,0 +1,187 @@
"""Unit tests for the SurfSense plugin entry-point loader (Tier 6)."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from langchain.agents.middleware import AgentMiddleware
from app.agents.new_chat.plugin_loader import (
PLUGIN_ENTRY_POINT_GROUP,
PluginContext,
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
from app.agents.new_chat.plugins.year_substituter import (
_YearSubstituterMiddleware,
make_middleware as year_substituter_factory,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _DummyMiddleware(AgentMiddleware):
    """Trivial middleware used as the success-path return value."""

    # AgentMiddleware subclasses declare the tools they contribute; none here.
    tools = ()


def _ctx() -> PluginContext:
    # Minimal-but-valid context; the llm is never actually invoked in these tests.
    return PluginContext.build(
        search_space_id=1,
        user_id="u",
        thread_visibility="PRIVATE",  # type: ignore[arg-type]
        llm=MagicMock(),
    )
class _FakeEntryPoint:
"""Stand-in for ``importlib.metadata.EntryPoint``."""
def __init__(self, name: str, factory) -> None:
self.name = name
self._factory = factory
def load(self):
return self._factory
# ---------------------------------------------------------------------------
# Loader behaviour
# ---------------------------------------------------------------------------
class TestPluginLoaderBasics:
    """Allowlist gating: nothing loads unless it is explicitly allowed."""

    def test_returns_empty_when_allowlist_is_empty(self) -> None:
        assert load_plugin_middlewares(_ctx(), allowed_plugin_names=[]) == []

    def test_skips_non_allowlisted_plugin(self) -> None:
        called = []

        def factory(_):  # would be an obvious bug if called
            called.append(True)
            return _DummyMiddleware()

        ep = _FakeEntryPoint("dangerous_plugin", factory)
        with patch(
            "app.agents.new_chat.plugin_loader.entry_points",
            return_value=[ep],
        ):
            result = load_plugin_middlewares(_ctx(), allowed_plugin_names=["allowed_only"])
        assert result == []
        # The factory must not even be invoked for a skipped plugin.
        assert not called

    def test_loads_allowlisted_plugin(self) -> None:
        ep = _FakeEntryPoint("year_substituter", year_substituter_factory)
        with patch(
            "app.agents.new_chat.plugin_loader.entry_points",
            return_value=[ep],
        ):
            result = load_plugin_middlewares(
                _ctx(), allowed_plugin_names={"year_substituter"}
            )
        assert len(result) == 1
        assert isinstance(result[0], _YearSubstituterMiddleware)
class TestPluginLoaderIsolation:
    """A misbehaving plugin must never break agent construction."""

    def test_factory_exception_is_isolated(self) -> None:
        def crashing_factory(_):
            raise RuntimeError("boom")

        ep = _FakeEntryPoint("buggy", crashing_factory)
        with patch(
            "app.agents.new_chat.plugin_loader.entry_points",
            return_value=[ep],
        ):
            result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"buggy"})
        assert result == []  # construction continued without the plugin

    def test_non_middleware_return_is_rejected(self) -> None:
        def bad_factory(_):
            return "not a middleware"

        ep = _FakeEntryPoint("liar", bad_factory)
        with patch(
            "app.agents.new_chat.plugin_loader.entry_points",
            return_value=[ep],
        ):
            result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"liar"})
        # Factory return values that are not middleware instances are dropped.
        assert result == []

    def test_load_phase_exception_is_isolated(self) -> None:
        # Failure during ``EntryPoint.load`` (import time), not the factory call.
        class _BrokenEP:
            name = "broken"

            def load(self):
                raise ImportError("cannot import")

        with patch(
            "app.agents.new_chat.plugin_loader.entry_points",
            return_value=[_BrokenEP()],
        ):
            result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"broken"})
        assert result == []

    def test_one_failure_does_not_block_others(self) -> None:
        """Two plugins; one crashes during factory; the other still loads."""

        def crashing_factory(_):
            raise RuntimeError("boom")

        eps = [
            _FakeEntryPoint("crashing", crashing_factory),
            _FakeEntryPoint("ok", year_substituter_factory),
        ]
        with patch(
            "app.agents.new_chat.plugin_loader.entry_points", return_value=eps
        ):
            result = load_plugin_middlewares(
                _ctx(), allowed_plugin_names={"crashing", "ok"}
            )
        assert len(result) == 1
        assert isinstance(result[0], _YearSubstituterMiddleware)
class TestAllowlistEnv:
    """Parsing of the ``SURFSENSE_ALLOWED_PLUGINS`` environment variable."""

    def test_empty_env_returns_empty_set(self, monkeypatch) -> None:
        monkeypatch.delenv("SURFSENSE_ALLOWED_PLUGINS", raising=False)
        assert load_allowed_plugin_names_from_env() == set()

    def test_parses_comma_separated_value(self, monkeypatch) -> None:
        # Whitespace padding and trailing commas must be tolerated.
        raw = " year_substituter , noisy , "
        monkeypatch.setenv("SURFSENSE_ALLOWED_PLUGINS", raw)
        expected = {"year_substituter", "noisy"}
        assert load_allowed_plugin_names_from_env() == expected
class TestPluginContext:
    """Shape of the context dict handed to plugin factories."""

    def test_build_includes_required_fields(self) -> None:
        fake_llm = MagicMock()
        ctx = PluginContext.build(
            search_space_id=42,
            user_id="user-1",
            thread_visibility="PRIVATE",  # type: ignore[arg-type]
            llm=fake_llm,
        )
        assert ctx["search_space_id"] == 42
        assert ctx["user_id"] == "user-1"
        assert ctx["llm"] is fake_llm

    def test_does_not_carry_secrets_or_db_session(self) -> None:
        ctx = _ctx()
        # If a future change tries to add these keys, this test will fail loudly.
        forbidden_keys = ("api_key", "secret", "db_session", "session")
        for key in forbidden_keys:
            assert key not in ctx
class TestEntryPointGroup:
    """The entry-point group name is a public contract for plugin authors."""

    def test_group_name_matches_pyproject_convention(self) -> None:
        # Plugins register under `surfsense.plugins` in their pyproject.toml.
        assert PLUGIN_ENTRY_POINT_GROUP == "surfsense.plugins"

View file

@ -0,0 +1,107 @@
"""Tests for RetryAfterMiddleware Retry-After parsing and retry decision logic."""
from __future__ import annotations
import pytest
from app.agents.new_chat.middleware.retry_after import (
RetryAfterMiddleware,
_extract_retry_after_seconds,
_is_non_retryable,
)
pytestmark = pytest.mark.unit
class _FakeResponse:
def __init__(self, headers: dict[str, str]) -> None:
self.headers = headers
class _FakeRateLimit(Exception):
    """Exception mimicking a provider rate-limit error with optional headers."""

    def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None:
        super().__init__(msg)
        # Only attach ``.response`` when headers were given, mirroring real
        # client exceptions that may lack a response object entirely.
        if headers is not None:
            self.response = _FakeResponse(headers)
class TestExtractRetryAfter:
    """Parsing of retry-delay hints from exception headers and messages."""

    def test_seconds_header(self) -> None:
        exc = _FakeRateLimit("rate", {"Retry-After": "30"})
        assert _extract_retry_after_seconds(exc) == 30.0

    def test_milliseconds_header_overrides_seconds(self) -> None:
        # Both headers present: the millisecond hint must win, as the test
        # name promises.  (Previously only the ms header was supplied, so
        # the override precedence was never actually exercised.)
        exc = _FakeRateLimit(
            "rate", {"Retry-After": "30", "retry-after-ms": "1500"}
        )
        assert _extract_retry_after_seconds(exc) == 1.5

    def test_milliseconds_header_alone(self) -> None:
        exc = _FakeRateLimit("rate", {"retry-after-ms": "1500"})
        assert _extract_retry_after_seconds(exc) == 1.5

    def test_case_insensitive(self) -> None:
        exc = _FakeRateLimit("rate", {"RETRY-AFTER": "12"})
        assert _extract_retry_after_seconds(exc) == 12.0

    def test_falls_back_to_message_regex(self) -> None:
        # No headers at all: the "retry after N seconds" message is parsed.
        exc = Exception("Please retry after 7 seconds")
        assert _extract_retry_after_seconds(exc) == 7.0

    def test_returns_none_when_no_hint(self) -> None:
        exc = Exception("oops")
        assert _extract_retry_after_seconds(exc) is None

    def test_handles_missing_headers_attr(self) -> None:
        # Exceptions without a ``.response`` attribute must not crash parsing.
        exc = ValueError("no headers")
        assert _extract_retry_after_seconds(exc) is None
class TestIsNonRetryable:
    """Classification of exceptions that must never be retried."""

    @pytest.mark.parametrize(
        "name",
        ["ContextWindowExceededError", "AuthenticationError", "InvalidRequestError"],
    )
    def test_non_retryable_classes(self, name: str) -> None:
        # Synthesize an exception class carrying the provider-style name.
        exc_cls = type(name, (Exception,), {})
        assert _is_non_retryable(exc_cls("x")) is True

    def test_generic_exception_is_retryable(self) -> None:
        assert _is_non_retryable(RuntimeError("transient")) is False
class TestDelayCalculation:
    """Delay computation: exponential backoff vs server-provided hints."""

    def test_takes_max_of_backoff_and_header(self) -> None:
        mw = RetryAfterMiddleware(max_retries=3, initial_delay=1.0, jitter=False)
        exc = _FakeRateLimit("rl", {"retry-after": "10"})
        # Backoff for attempt 0 is 1.0s; the 10s header hint must win.
        delay = mw._delay_for_attempt(0, exc)
        assert delay == pytest.approx(10.0)

    def test_uses_backoff_when_no_header(self) -> None:
        mw = RetryAfterMiddleware(
            max_retries=3, initial_delay=2.0, backoff_factor=2.0, jitter=False
        )
        delay = mw._delay_for_attempt(2, RuntimeError("transient"))
        # initial_delay * backoff_factor**attempt = 2 * 2^2 = 8
        assert delay == pytest.approx(8.0)

    def test_caps_at_max_delay(self) -> None:
        mw = RetryAfterMiddleware(
            max_retries=3,
            initial_delay=10.0,
            backoff_factor=10.0,
            max_delay=15.0,
            jitter=False,
        )
        delay = mw._delay_for_attempt(5, RuntimeError("x"))
        # With jitter disabled this is deterministic: the raw backoff
        # (10 * 10**5) far exceeds the cap, so the result must be exactly
        # max_delay — the previous `<= 15.0` would also pass for 0.0.
        assert delay == pytest.approx(15.0)
class TestShouldRetry:
    """Retry gating: defaults plus a user-supplied ``retry_on`` predicate."""

    def test_default_retries_generic(self) -> None:
        mw = RetryAfterMiddleware()
        assert mw._should_retry(RuntimeError("transient")) is True

    def test_default_skips_non_retryable(self) -> None:
        mw = RetryAfterMiddleware()
        # A name-only match is enough to classify as non-retryable.
        too_big_cls = type("ContextWindowExceededError", (Exception,), {})
        assert mw._should_retry(too_big_cls("too big")) is False

    def test_custom_retry_on(self) -> None:
        mw = RetryAfterMiddleware(retry_on=lambda exc: isinstance(exc, ValueError))
        assert mw._should_retry(ValueError()) is True
        assert mw._should_retry(KeyError()) is False

View file

@ -0,0 +1,242 @@
"""Tests for the skills backends used by SurfSense's SkillsMiddleware."""
from __future__ import annotations
import asyncio
from pathlib import Path
import pytest
from app.agents.new_chat.middleware.skills_backends import (
SKILLS_BUILTIN_PREFIX,
SKILLS_SPACE_PREFIX,
BuiltinSkillsBackend,
SearchSpaceSkillsBackend,
build_skills_backend_factory,
default_skills_sources,
)
@pytest.fixture
def skills_root(tmp_path: Path) -> Path:
    """Build a small synthetic skill-tree used by the tests."""
    root = tmp_path / "skills"
    template = "---\nname: {name}\ndescription: {name} skill\n---\n# {title}\n"
    # Two well-formed skills (directory + SKILL.md manifest each)...
    for skill_name in ("alpha", "beta"):
        skill_dir = root / skill_name
        skill_dir.mkdir(parents=True)
        (skill_dir / "SKILL.md").write_text(
            template.format(name=skill_name, title=skill_name.capitalize())
        )
    # ...plus one stray file that is not a skill directory.
    (root / "_orphan_file.md").write_text("not a skill, just a stray file")
    return root
class TestBuiltinSkillsBackendListing:
    """Directory-listing behavior of :class:`BuiltinSkillsBackend`."""

    def test_lists_skill_directories_at_root(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        infos = backend.ls_info("/")
        names = {info["path"] for info in infos}
        # Skill directories and stray files both appear in the root listing.
        assert "/alpha" in names
        assert "/beta" in names
        assert "/_orphan_file.md" in names
        # Only the skill directories carry the is_dir flag.
        for info in infos:
            if info["path"] in {"/alpha", "/beta"}:
                assert info["is_dir"] is True

    def test_lists_skill_md_under_skill_directory(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        infos = backend.ls_info("/alpha")
        paths = {info["path"] for info in infos}
        assert paths == {"/alpha/SKILL.md"}
        assert infos[0]["is_dir"] is False
        assert infos[0]["size"] > 0

    def test_returns_empty_for_missing_path(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        assert backend.ls_info("/nonexistent") == []

    def test_returns_empty_when_root_missing(self, tmp_path: Path) -> None:
        # A missing skills root degrades to empty results, not exceptions.
        backend = BuiltinSkillsBackend(tmp_path / "definitely-missing")
        assert backend.ls_info("/") == []
        assert backend.download_files(["/x/SKILL.md"])[0].error == "file_not_found"

    def test_refuses_path_traversal(self, skills_root: Path) -> None:
        # `..` segments must never escape the backend root.
        backend = BuiltinSkillsBackend(skills_root)
        assert backend.ls_info("/../../../etc") == []
        responses = backend.download_files(["/../../../etc/passwd"])
        assert responses[0].error == "invalid_path"
class TestBuiltinSkillsBackendDownload:
    """File-download behavior, including error classification."""

    def test_downloads_skill_md_content(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        responses = backend.download_files(["/alpha/SKILL.md", "/beta/SKILL.md"])
        assert len(responses) == 2
        assert responses[0].path == "/alpha/SKILL.md"
        assert responses[0].content is not None
        # Content is returned as bytes; check against the fixture frontmatter.
        assert b"name: alpha" in responses[0].content
        assert responses[1].error is None

    def test_marks_directory_as_is_directory_error(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        responses = backend.download_files(["/alpha"])
        assert responses[0].error == "is_directory"

    def test_marks_missing_file_as_file_not_found(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        responses = backend.download_files(["/alpha/missing.md"])
        assert responses[0].error == "file_not_found"
        assert responses[0].content is None

    def test_response_path_matches_input_for_correlation(
        self, skills_root: Path
    ) -> None:
        # Callers correlate responses to requests by position and echoed
        # path, so both the order and the path values must be preserved.
        backend = BuiltinSkillsBackend(skills_root)
        inputs = ["/alpha/SKILL.md", "/missing.md", "/beta/SKILL.md"]
        responses = backend.download_files(inputs)
        assert [r.path for r in responses] == inputs
class TestBuiltinSkillsBackendIntegration:
    """Mirror the call sequence the SkillsMiddleware actually uses."""

    def test_skills_middleware_call_pattern(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        # Step 1: async-list the root and pick out skill directories.
        listing = asyncio.run(backend.als_info("/"))
        directories = [entry["path"] for entry in listing if entry.get("is_dir")]
        assert sorted(directories) == ["/alpha", "/beta"]
        # Step 2: async-download each directory's SKILL.md manifest.
        manifest_paths = [f"{directory}/SKILL.md" for directory in directories]
        downloads = asyncio.run(backend.adownload_files(manifest_paths))
        for response in downloads:
            assert response.error is None
            assert response.content is not None
class TestBundledSkills:
    """The repository must ship the starter skills the Tier 4 plan requires."""

    def test_default_root_resolves_to_repo_skills_dir(self) -> None:
        # No-arg construction points at the packaged skills/builtin tree.
        backend = BuiltinSkillsBackend()
        assert backend.root.name == "builtin"
        assert backend.root.parent.name == "skills"

    def test_bundled_starter_skills_are_present(self) -> None:
        backend = BuiltinSkillsBackend()
        infos = backend.ls_info("/")
        names = {info["path"].lstrip("/") for info in infos if info.get("is_dir")}
        # Five starter skills required by the Tier 4 plan.
        for required in (
            "kb-research",
            "report-writing",
            "meeting-prep",
            "slack-summary",
            "email-drafting",
        ):
            assert required in names, f"missing starter skill: {required}"

    def test_each_starter_skill_has_valid_skill_md(self) -> None:
        # Every bundled skill directory must carry a parseable SKILL.md
        # with YAML frontmatter declaring at least name and description.
        backend = BuiltinSkillsBackend()
        infos = backend.ls_info("/")
        skill_dirs = [info["path"] for info in infos if info.get("is_dir")]
        for skill_dir in skill_dirs:
            md_path = f"{skill_dir}/SKILL.md"
            response = backend.download_files([md_path])[0]
            assert response.error is None, f"missing SKILL.md in {skill_dir}"
            # Normalize Windows line endings before the frontmatter checks.
            content = response.content.decode("utf-8").replace("\r\n", "\n")
            assert content.startswith("---\n"), f"missing frontmatter in {skill_dir}"
            assert "\nname:" in content
            assert "\ndescription:" in content
class _FakeKBBackend:
"""Stand-in for :class:`KBPostgresBackend` with the two methods we need."""
def __init__(self, listing: list[dict], file_contents: dict[str, bytes]) -> None:
self._listing = listing
self._file_contents = file_contents
self.last_ls_path: str | None = None
self.last_download_paths: list[str] | None = None
async def als_info(self, path: str):
self.last_ls_path = path
return self._listing
async def adownload_files(self, paths):
from deepagents.backends.protocol import FileDownloadResponse
self.last_download_paths = list(paths)
out: list[FileDownloadResponse] = []
for p in paths:
content = self._file_contents.get(p)
if content is None:
out.append(FileDownloadResponse(path=p, error="file_not_found"))
else:
out.append(FileDownloadResponse(path=p, content=content))
return out
class TestSearchSpaceSkillsBackend:
    """Path remapping between the skills namespace and the KB namespace."""

    def test_remaps_paths_when_listing(self) -> None:
        listing = [
            {"path": "/documents/_skills/policy", "is_dir": True},
            {"path": "/documents/_skills/policy/SKILL.md", "is_dir": False},
            {"path": "/documents/other-folder/x.md", "is_dir": False},
        ]
        kb = _FakeKBBackend(listing=listing, file_contents={})
        backend = SearchSpaceSkillsBackend(kb)
        infos = asyncio.run(backend.als_info("/"))
        # The backend lists under the KB-side skills root...
        assert kb.last_ls_path == "/documents/_skills"
        paths = [info["path"] for info in infos]
        # ...and exposes results relative to the skills namespace.
        assert "/policy" in paths
        assert "/policy/SKILL.md" in paths
        # Unrelated KB documents must NOT leak into the skills namespace.
        assert all(not p.startswith("/documents") for p in paths)

    def test_remaps_paths_when_downloading(self) -> None:
        kb = _FakeKBBackend(
            listing=[],
            file_contents={
                "/documents/_skills/policy/SKILL.md": b"---\nname: policy\n---\n",
            },
        )
        backend = SearchSpaceSkillsBackend(kb)
        responses = asyncio.run(backend.adownload_files(["/policy/SKILL.md"]))
        # The request is translated to the KB path; the response echoes the
        # caller's skills-namespace path for correlation.
        assert kb.last_download_paths == ["/documents/_skills/policy/SKILL.md"]
        assert responses[0].path == "/policy/SKILL.md"
        assert responses[0].error is None
        assert responses[0].content is not None

    def test_sync_methods_raise_not_implemented(self) -> None:
        # The KB backend is async-only, so the sync protocol surface must
        # fail loudly rather than silently return nothing.
        backend = SearchSpaceSkillsBackend(_FakeKBBackend([], {}))
        with pytest.raises(NotImplementedError):
            backend.ls_info("/")
        with pytest.raises(NotImplementedError):
            backend.download_files(["/x"])

    def test_custom_kb_root_is_honored(self) -> None:
        kb = _FakeKBBackend(
            listing=[
                {"path": "/skills_admin/x", "is_dir": True},
            ],
            file_contents={},
        )
        backend = SearchSpaceSkillsBackend(kb, kb_root="/skills_admin")
        infos = asyncio.run(backend.als_info("/"))
        assert kb.last_ls_path == "/skills_admin"
        assert infos[0]["path"] == "/x"
class TestBackendFactory:
    """Construction of the composite skills backend."""

    def test_builtin_only_factory_returns_composite(self) -> None:
        from deepagents.backends.composite import CompositeBackend

        factory = build_skills_backend_factory()
        backend = factory(runtime=None)  # type: ignore[arg-type]
        assert isinstance(backend, CompositeBackend)
        # Only the builtin route is mounted when no search space is wired.
        assert SKILLS_BUILTIN_PREFIX in backend.routes
        assert SKILLS_SPACE_PREFIX not in backend.routes

    def test_default_skills_sources_lists_builtin_then_space(self) -> None:
        expected = [SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX]
        assert default_skills_sources() == expected

View file

@ -0,0 +1,338 @@
"""Tests for the specialized subagents (explore / report_writer / connector_negotiator)."""
from __future__ import annotations
from langchain_core.tools import tool
from app.agents.new_chat.middleware.permission import PermissionMiddleware
from app.agents.new_chat.subagents import (
build_connector_negotiator_subagent,
build_explore_subagent,
build_report_writer_subagent,
build_specialized_subagents,
)
from app.agents.new_chat.subagents.config import (
EXPLORE_READ_TOOLS,
REPORT_WRITER_TOOLS,
WRITE_TOOL_DENY_PATTERNS,
)
# ---------------------------------------------------------------------------
# Fake tools used to verify filtering & permission behavior
# ---------------------------------------------------------------------------
# NOTE: each docstring below becomes the tool's runtime description via the
# ``@tool`` decorator, so the docstring text is behavior — do not edit it
# casually.


# Read-style tools (expected to survive read-only filtering).
@tool
def search_surfsense_docs(query: str) -> str:
    """Search the user's KB."""
    return ""


@tool
def web_search(query: str) -> str:
    """Search the public web."""
    return ""


@tool
def scrape_webpage(url: str) -> str:
    """Scrape a single webpage."""
    return ""


@tool
def read_file(path: str) -> str:
    """Read a file."""
    return ""


@tool
def ls_tree(path: str) -> str:
    """List a tree."""
    return ""


@tool
def grep(pattern: str) -> str:
    """Grep."""
    return ""


# Write/mutating tools (expected to be filtered out or deny-ruled).
@tool
def update_memory(content: str) -> str:
    """Update the user's memory."""
    return ""


@tool
def edit_file(path: str, old: str, new: str) -> str:
    """Edit a file."""
    return ""


@tool
def linear_create_issue(title: str) -> str:
    """Create a Linear issue."""
    return ""


@tool
def slack_send_message(channel: str, text: str) -> str:
    """Send a Slack message."""
    return ""


# Connector/report tools with special-case handling in the subagents.
@tool
def get_connected_accounts() -> str:
    """List connected accounts."""
    return ""


@tool
def generate_report(topic: str) -> str:
    """Generate a report artifact."""
    return ""


# Full registry handed to the subagent builders in the tests below.
ALL_TOOLS = [
    search_surfsense_docs,
    web_search,
    scrape_webpage,
    read_file,
    ls_tree,
    grep,
    update_memory,
    edit_file,
    linear_create_issue,
    slack_send_message,
    get_connected_accounts,
    generate_report,
]
class TestExploreSubagent:
    """The explore subagent must be strictly read-only."""

    def test_only_read_tools_are_exposed(self) -> None:
        spec = build_explore_subagent(tools=ALL_TOOLS)
        names = {t.name for t in spec["tools"]}  # type: ignore[index]
        # Intersection: the config may allow tools the registry lacks.
        assert names == EXPLORE_READ_TOOLS & {t.name for t in ALL_TOOLS}
        assert "update_memory" not in names
        assert "linear_create_issue" not in names
        assert "edit_file" not in names

    def test_includes_permission_middleware_with_deny_rules(self) -> None:
        spec = build_explore_subagent(tools=ALL_TOOLS)
        permission_mws = [
            m for m in spec["middleware"] if isinstance(m, PermissionMiddleware)  # type: ignore[index]
        ]
        assert len(permission_mws) == 1
        # NOTE(review): reaches into the private ``_static_rulesets``
        # attribute; brittle if PermissionMiddleware internals change.
        ruleset = permission_mws[0]._static_rulesets[0]
        assert ruleset.origin == "subagent_explore"
        deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
        assert "update_memory" in deny_patterns
        assert "edit_file" in deny_patterns
        assert "*create*" in deny_patterns
        assert "*send*" in deny_patterns

    def test_skills_inherits_default_sources(self) -> None:
        spec = build_explore_subagent(tools=ALL_TOOLS)
        assert spec["skills"] == ["/skills/builtin/", "/skills/space/"]  # type: ignore[index]

    def test_name_and_description_match_contract(self) -> None:
        spec = build_explore_subagent(tools=ALL_TOOLS)
        assert spec["name"] == "explore"
        assert "read-only" in spec["description"].lower()

    def test_includes_dedup_and_patch_middleware(self) -> None:
        from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware

        from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware

        spec = build_explore_subagent(tools=ALL_TOOLS)
        types = {type(m) for m in spec["middleware"]}  # type: ignore[index]
        assert PatchToolCallsMiddleware in types
        assert DedupHITLToolCallsMiddleware in types
class TestReportWriterSubagent:
    """The report writer may research and generate reports, nothing else."""

    def test_exposes_only_report_writing_tools(self) -> None:
        spec = build_report_writer_subagent(tools=ALL_TOOLS)
        names = {t.name for t in spec["tools"]}  # type: ignore[index]
        assert names == REPORT_WRITER_TOOLS & {t.name for t in ALL_TOOLS}
        assert "generate_report" in names
        assert "search_surfsense_docs" in names

    def test_deny_rules_block_writes_but_allow_generate_report(self) -> None:
        spec = build_report_writer_subagent(tools=ALL_TOOLS)
        permission_mws = [
            m for m in spec["middleware"] if isinstance(m, PermissionMiddleware)  # type: ignore[index]
        ]
        ruleset = permission_mws[0]._static_rulesets[0]
        deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
        assert "update_memory" in deny_patterns
        # generate_report MUST not be denied — it's the whole point of the subagent.
        assert "generate_report" not in deny_patterns
        # No deny pattern should match `generate_report` either.
        assert all(
            not _wildcard_matches(pattern, "generate_report")
            for pattern in deny_patterns
        )
class TestConnectorNegotiatorSubagent:
    """The negotiator inherits every tool and is fenced by deny rules."""

    def test_inherits_all_parent_tools(self) -> None:
        spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
        exposed = {t.name for t in spec["tools"]}  # type: ignore[index]
        # Every parent tool is inherited; the deny ruleset enforces behavior
        # at execution time instead of trimming the tool list.
        assert exposed == {t.name for t in ALL_TOOLS}

    def test_get_connected_accounts_is_present(self) -> None:
        spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
        exposed = {t.name for t in spec["tools"]}  # type: ignore[index]
        assert "get_connected_accounts" in exposed

    def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None:
        spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
        permission_mw = next(
            m for m in spec["middleware"] if isinstance(m, PermissionMiddleware)  # type: ignore[index]
        )
        ruleset = permission_mw._static_rulesets[0]
        deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"}
        # `linear_create_issue` matches the `*_create` deny pattern.
        for mutating in ("linear_create_issue", "slack_send_message"):
            assert any(
                _wildcard_matches(pattern, mutating) for pattern in deny_patterns
            )
class TestBuildSpecializedSubagents:
    """Composition of the full specialized-subagent roster."""

    def test_returns_three_specs(self) -> None:
        specs = build_specialized_subagents(tools=ALL_TOOLS)
        names = [s["name"] for s in specs]  # type: ignore[index]
        # Order is asserted exactly: it is part of the roster's contract.
        assert names == ["explore", "report_writer", "connector_negotiator"]

    def test_all_specs_have_unique_names(self) -> None:
        specs = build_specialized_subagents(tools=ALL_TOOLS)
        names = [s["name"] for s in specs]  # type: ignore[index]
        assert len(set(names)) == len(names)

    def test_extra_middleware_is_prepended_to_each_spec(self) -> None:
        """Sentinel middleware passed via ``extra_middleware`` must appear
        in each subagent's ``middleware`` list, before the local rules.

        This guards against the regression where specialized subagents
        promised filesystem tools (``read_file``, ``ls``, ``grep``) in
        their system prompts but had no filesystem middleware mounted.
        """
        class _Sentinel:
            pass

        sentinel = _Sentinel()
        specs = build_specialized_subagents(
            tools=ALL_TOOLS, extra_middleware=[sentinel]
        )
        for spec in specs:
            mws = spec["middleware"]  # type: ignore[index]
            assert sentinel in mws
            # The sentinel must appear *before* the permission middleware
            # (subagent-local rules), preserving the documented composition
            # order: extra → custom → patch → dedup.
            sentinel_idx = mws.index(sentinel)
            perm_idx = next(
                (i for i, m in enumerate(mws)
                 if isinstance(m, PermissionMiddleware)),
                None,
            )
            assert perm_idx is not None
            assert sentinel_idx < perm_idx
class TestFilterToolsWarningSuppression:
    """Names provided by middleware (read_file, ls, grep, …) must not
    trigger the spurious "missing" warning in :func:`_filter_tools`."""

    def test_middleware_provided_names_are_silent(self, caplog) -> None:
        import logging

        from app.agents.new_chat.subagents.config import _filter_tools

        # NOTE(review): the capture threshold is INFO, so the "missing"
        # message is presumably logged at INFO level — confirm against
        # `_filter_tools` if this ever flakes.
        with caplog.at_level(logging.INFO, logger="app.agents.new_chat.subagents.config"):
            # Allowed set asks for two registry tools (one present, one
            # not) plus a bunch of middleware-provided names.
            _filter_tools(
                [search_surfsense_docs],
                allowed_names={
                    "search_surfsense_docs",
                    "scrape_webpage",  # legitimately missing → should warn
                    "read_file",  # mw-provided → suppressed
                    "ls",
                    "grep",
                    "glob",
                    "write_todos",
                },
            )
        warnings = [
            r.message for r in caplog.records if r.levelno >= logging.INFO
        ]
        # Exactly one warning, and it should mention scrape_webpage but not
        # any middleware-provided name. Inspect the rendered "missing"
        # list (between the brackets) so we don't false-match substrings
        # like ``ls`` inside ``available``.
        assert len(warnings) == 1, warnings
        msg = warnings[0]
        assert "scrape_webpage" in msg
        bracket_section = msg.split("missing: ", 1)[1]
        for noisy in ("read_file", "ls", "grep", "glob", "write_todos"):
            assert f"'{noisy}'" not in bracket_section, msg
class TestDenyPatternsCoverage:
    """WRITE_TOOL_DENY_PATTERNS must catch writes and spare reads."""

    def test_deny_patterns_cover_canonical_write_tools(self) -> None:
        for tool_name in (
            "update_memory",
            "edit_file",
            "write_file",
            "move_file",
            "mkdir",
            "linear_create_issue",
            "linear_update_issue",
            "linear_delete_issue",
            "slack_send_message",
            "create_index",
            "update_account",
            "delete_record",
            "send_email",
        ):
            matched = any(
                _wildcard_matches(pattern, tool_name)
                for pattern in WRITE_TOOL_DENY_PATTERNS
            )
            assert matched, f"no deny pattern matches {tool_name!r}"

    def test_deny_patterns_do_not_match_safe_read_tools(self) -> None:
        for tool_name in (
            "search_surfsense_docs",
            "read_file",
            "ls_tree",
            "grep",
            "web_search",
            "scrape_webpage",
            "get_connected_accounts",
            "generate_report",
        ):
            matched = any(
                _wildcard_matches(pattern, tool_name)
                for pattern in WRITE_TOOL_DENY_PATTERNS
            )
            assert not matched, (
                f"deny pattern incorrectly matches read tool {tool_name!r}"
            )
def _wildcard_matches(pattern: str, value: str) -> bool:
    """Helper using the same matcher the rule evaluator does.

    Note the argument order: the production ``wildcard_match`` takes
    ``(value, pattern)`` while this helper takes ``(pattern, value)``.
    """
    # Local import keeps the module importable even when only the
    # non-permission tests are collected.
    from app.agents.new_chat.permissions import wildcard_match
    return wildcard_match(value, pattern)

View file

@ -0,0 +1,103 @@
"""Tests for ToolCallNameRepairMiddleware."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage
from app.agents.new_chat.middleware.tool_call_repair import (
ToolCallNameRepairMiddleware,
)
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME
pytestmark = pytest.mark.unit
def _make_state(message: AIMessage) -> dict:
return {"messages": [message]}
class _FakeRuntime:
def __init__(self, context: object | None = None) -> None:
self.context = context
class TestRepair:
    """Repair of hallucinated tool-call names emitted by the model."""

    def test_passthrough_when_name_matches(self) -> None:
        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"echo"}, fuzzy_match_threshold=None
        )
        msg = AIMessage(content="", tool_calls=[
            {"name": "echo", "args": {}, "id": "1"},
        ])
        out = mw.after_model(_make_state(msg), _FakeRuntime())
        assert out is None  # no change

    def test_lowercase_repair(self) -> None:
        # Case-only mismatches are repaired to the registered spelling.
        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"echo"}, fuzzy_match_threshold=None
        )
        msg = AIMessage(content="", tool_calls=[
            {"name": "Echo", "args": {"x": 1}, "id": "1"},
        ])
        out = mw.after_model(_make_state(msg), _FakeRuntime())
        assert out is not None
        repaired = out["messages"][0]
        assert repaired.tool_calls[0]["name"] == "echo"

    def test_invalid_fallback_when_no_match(self) -> None:
        # When the invalid-tool sink is registered, unknown names are routed
        # to it with the original name and an error message in the args.
        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"echo", INVALID_TOOL_NAME},
            fuzzy_match_threshold=None,
        )
        msg = AIMessage(content="", tool_calls=[
            {"name": "totally_different_name", "args": {"k": "v"}, "id": "1"},
        ])
        out = mw.after_model(_make_state(msg), _FakeRuntime())
        assert out is not None
        repaired_call = out["messages"][0].tool_calls[0]
        assert repaired_call["name"] == INVALID_TOOL_NAME
        assert repaired_call["args"]["tool"] == "totally_different_name"
        assert "totally_different_name" in repaired_call["args"]["error"]

    def test_no_invalid_means_skip_when_unknown(self) -> None:
        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"echo"}, fuzzy_match_threshold=None
        )
        msg = AIMessage(content="", tool_calls=[
            {"name": "unknown", "args": {}, "id": "1"},
        ])
        out = mw.after_model(_make_state(msg), _FakeRuntime())
        # No repair available; original returned unchanged (no update)
        assert out is None

    def test_fuzzy_match_works_when_enabled(self) -> None:
        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"search_documents"},
            fuzzy_match_threshold=0.7,
        )
        msg = AIMessage(content="", tool_calls=[
            {"name": "search_docments", "args": {}, "id": "1"},  # typo: dropped 'u'
        ])
        out = mw.after_model(_make_state(msg), _FakeRuntime())
        assert out is not None
        assert out["messages"][0].tool_calls[0]["name"] == "search_documents"

    def test_skips_when_no_messages(self) -> None:
        mw = ToolCallNameRepairMiddleware(registered_tool_names={"echo"})
        out = mw.after_model({"messages": []}, _FakeRuntime())
        assert out is None

    def test_runtime_context_extends_registered(self) -> None:
        from types import SimpleNamespace

        # Tool names supplied via runtime context are honored in addition
        # to the statically registered set.
        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"echo"}, fuzzy_match_threshold=None
        )
        msg = AIMessage(content="", tool_calls=[
            {"name": "DynamicTool", "args": {}, "id": "1"},
        ])
        runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"]))
        out = mw.after_model(_make_state(msg), runtime)
        assert out is not None
        assert out["messages"][0].tool_calls[0]["name"] == "dynamictool"

View file

@ -1,8 +1,10 @@
import pytest
from langchain_core.messages import AIMessage
from langchain_core.tools import StructuredTool
from app.agents.new_chat.middleware.dedup_tool_calls import (
DedupHITLToolCallsMiddleware,
wrap_dedup_key_by_arg_name,
)
pytestmark = pytest.mark.unit
@ -14,9 +16,34 @@ def _make_state(tool_calls: list[dict]) -> dict:
return {"messages": [msg]}
def _hitl_tool(name: str, *, dedup_arg: str) -> StructuredTool:
    """Build a tool with declarative ``dedup_key`` metadata.

    Mirrors the ``ToolDefinition.dedup_key`` -> ``tool.metadata["dedup_key"]``
    propagation done by :func:`build_tools` after the cleanup tier.
    """
    def _fn(**kwargs):
        # The tool body is irrelevant to dedup behavior; return a constant.
        return "ok"
    return StructuredTool.from_function(
        func=_fn,
        name=name,
        description="x",
        metadata={"dedup_key": wrap_dedup_key_by_arg_name(dedup_arg)},
    )
def test_duplicate_hitl_calls_reduced_to_first():
"""When the LLM emits the same HITL tool call twice, only the first is kept."""
mw = DedupHITLToolCallsMiddleware()
"""When the LLM emits the same HITL tool call twice, only the first is kept.
After the cleanup tier removed ``_NATIVE_HITL_TOOL_DEDUP_KEYS``, the
resolver is sourced from ``ToolDefinition.dedup_key`` propagated onto
``tool.metadata`` which the registry does at agent build time. The
test mirrors that wiring with an in-memory tool.
"""
tool = _hitl_tool("delete_calendar_event", dedup_arg="event_title_or_id")
mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
state = _make_state(
[

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,84 @@
"""Tests for the SurfSense OpenTelemetry shim (Tier 3b)."""
from __future__ import annotations
import pytest
from app.observability import otel
pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _reset_otel_state(monkeypatch: pytest.MonkeyPatch):
    """Force a clean OTel disabled state per test, then restore after."""
    # Clear any ambient exporter config so tests are hermetic, then hard-
    # disable via the kill switch before reloading module state.
    for env in ("OTEL_EXPORTER_OTLP_ENDPOINT", "SURFSENSE_DISABLE_OTEL"):
        monkeypatch.delenv(env, raising=False)
    monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true")
    otel.reload_for_tests()
    yield
    # Re-derive state after monkeypatch has unwound the env changes.
    otel.reload_for_tests()
def test_disabled_by_default_when_no_endpoint() -> None:
    # The autouse fixture leaves OTel hard-disabled and endpoint-less.
    assert otel.is_enabled() is False
def test_enabled_when_endpoint_configured(monkeypatch: pytest.MonkeyPatch) -> None:
    # Remove the kill switch set by the autouse fixture, then point at an
    # exporter endpoint: reload should report OTel enabled.
    monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
    assert otel.reload_for_tests() is True
def test_kill_switch_overrides_endpoint(monkeypatch: pytest.MonkeyPatch) -> None:
    # Even with a configured endpoint, the explicit kill switch wins.
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
    monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true")
    assert otel.reload_for_tests() is False
class TestNoopSpansWhenDisabled:
    """With OTel disabled every span helper must be a harmless no-op."""

    def test_generic_span_yields_noop(self) -> None:
        with otel.span("any.thing", attributes={"x": 1}) as sp:
            # Exercise the full no-op span surface.
            sp.set_attribute("y", 2)
            sp.set_attributes({"a": "b"})
            sp.add_event("evt")
            sp.record_exception(RuntimeError("ignored"))
            sp.set_status("ignored")
        # Reaching here without raising means the no-op is well-formed

    def test_exception_propagates_through_span(self) -> None:
        # The span must not swallow exceptions raised inside its body.
        with pytest.raises(ValueError), otel.span("err"):
            raise ValueError("boom")

    def test_each_helper_is_a_no_op_when_disabled(self) -> None:
        # One instance of every domain-specific span helper.
        helpers = [
            otel.tool_call_span("write_file", input_size=42),
            otel.model_call_span(model_id="openai:gpt-4o", provider="openai"),
            otel.kb_search_span(search_space_id=1, query_chars=99),
            otel.kb_persist_span(document_type="NOTE", document_id=7),
            otel.compaction_span(reason="overflow", messages_in=120),
            otel.interrupt_span(interrupt_type="permission_ask"),
            otel.permission_asked_span(permission="edit", pattern="/x/**"),
        ]
        for cm in helpers:
            with cm as sp:
                assert sp is not None
                sp.set_attribute("ok", True)
class TestEnabledIntegration:
    """When OTel is wired but no SDK exporter is bound, the API still works."""

    def test_span_attaches_attributes(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # Use the API tracer (no-op-ish but real Span objects).
        monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
        monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
        assert otel.reload_for_tests() is True
        # Should not raise even when set_attributes/record_exception fall through
        # to an SDK that isn't actually installed.
        with otel.tool_call_span("scrape_webpage", input_size=10) as sp:
            sp.set_attribute("tool.output.size", 200)
            sp.set_attribute("tool.truncated", False)
        with otel.model_call_span(model_id="m", provider="p") as sp:
            sp.set_attribute("retry.count", 3)

View file

@ -0,0 +1,56 @@
"""Unit tests for the agent revert service (Tier 5.3)."""
from __future__ import annotations
from typing import Any
from app.services.revert_service import can_revert
class _FakeAction:
def __init__(self, *, user_id: Any, tool_name: str = "edit_file") -> None:
self.user_id = user_id
self.tool_name = tool_name
class TestCanRevert:
    """Authorization matrix for reverting a recorded agent action."""

    def test_owner_can_revert_their_own_action(self) -> None:
        action = _FakeAction(user_id="user-123")
        assert can_revert(requester_user_id="user-123", action=action, is_admin=False)

    def test_other_user_cannot_revert(self) -> None:
        action = _FakeAction(user_id="user-123")
        assert not can_revert(
            requester_user_id="someone-else", action=action, is_admin=False
        )

    def test_admin_always_allowed(self) -> None:
        action = _FakeAction(user_id="user-123")
        assert can_revert(requester_user_id="anybody", action=action, is_admin=True)

    def test_admin_can_revert_anonymous_action(self) -> None:
        # Actions recorded without a user still yield to the admin override.
        action = _FakeAction(user_id=None)
        assert can_revert(requester_user_id="admin", action=action, is_admin=True)

    def test_anonymous_action_blocks_non_admin(self) -> None:
        action = _FakeAction(user_id=None)
        assert not can_revert(
            requester_user_id="user-1", action=action, is_admin=False
        )

    def test_uuid_string_normalization(self) -> None:
        """``user_id`` may be a UUID object; comparison should still work."""
        import uuid

        owner = uuid.uuid4()
        action = _FakeAction(user_id=owner)
        # Same UUID, passed as string from the requesting side.
        assert can_revert(requester_user_id=str(owner), action=action, is_admin=False)