mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
feat: updated agent harness
This commit is contained in:
parent
9ec9b64348
commit
31a372bb84
139 changed files with 12583 additions and 1111 deletions
146
surfsense_backend/tests/integration/harness/__init__.py
Normal file
146
surfsense_backend/tests/integration/harness/__init__.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
"""
|
||||
Integration test harness for the SurfSense agent stack.
|
||||
|
||||
The plan calls for an ``LLMToolEmulator``-backed harness for end-to-end
|
||||
replay of ``stream_new_chat``. The currently-installed langchain version
|
||||
does not expose ``LLMToolEmulator``, so this harness builds the equivalent
|
||||
on top of :class:`langchain_core.language_models.fake_chat_models.FakeMessagesListChatModel`.
|
||||
|
||||
The harness lets a test author script a sequence of model responses
|
||||
(text + optional tool calls) and replay them against the new_chat agent
|
||||
graph. Tools are stubbed via ``StubToolSpec`` -> ``langchain_core.tools.tool``
|
||||
decorator and execute deterministic Python callbacks.
|
||||
|
||||
Used by:
|
||||
- ``tests/integration/agents/new_chat/test_feature_flag_smoke.py`` to
|
||||
confirm the kill-switch path produces identical-shape output regardless
|
||||
of which middleware flags are toggled.
|
||||
- Future per-tier PRs to record golden transcripts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeMessagesListChatModel,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
|
||||
|
||||
class _ToolBindingFakeChatModel(FakeMessagesListChatModel):
    """Fake chat model that accepts ``bind_tools`` calls as a no-op.

    The base ``FakeMessagesListChatModel.bind_tools`` raises
    ``NotImplementedError``, yet ``langchain.agents.create_agent``
    unconditionally calls ``bind_tools`` to register the tool schemas.
    The scripted fake already emits exactly the tool calls a test needs,
    so honoring the schema is unnecessary — the override simply hands
    back the same model instance.
    """

    def bind_tools(  # type: ignore[override]
        self,
        tools: Sequence[Any],
        *,
        tool_choice: Any = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, AIMessage]:
        # Parameters are accepted purely for signature compatibility
        # with real chat models; none of them influence the scripted output.
        return self
|
||||
|
||||
@dataclass
class StubToolSpec:
    """Declarative description of a test-mode tool.

    Bundles a tool name, a human-readable description, and a deterministic
    Python callback that stands in for the real implementation.
    """

    # Tool name exactly as the scripted model will reference it.
    name: str
    # Description surfaced to the (fake) model.
    description: str
    # Deterministic callback invoked with the tool-call kwargs.
    handler: Callable[..., Any]
    # NOTE(review): declared but currently unused by ``build`` — confirm
    # whether it was meant to be forwarded to the ``tool`` decorator.
    args_schema: dict[str, Any] | None = None

    def build(self) -> BaseTool:
        """Materialize this spec as a ``langchain_core.tools.BaseTool``."""
        spec = self

        @tool(name_or_callable=spec.name, description=spec.description)
        def _stub_tool(**call_kwargs: Any) -> Any:
            # Delegate straight to the scripted handler.
            return spec.handler(**call_kwargs)

        return _stub_tool
|
||||
|
||||
@dataclass
class ScriptedTurn:
    """A single scripted assistant turn.

    ``text`` holds the assistant's textual reply (empty string for a pure
    tool call).  ``tool_calls`` holds dicts of the shape
    ``{name, args, id}``; when non-empty, the agent routes to those tools
    and the script must supply a follow-up turn.
    """

    # Assistant text for this turn; may be the empty string.
    text: str = ""
    # Tool-call payloads (``{name, args, id}``) emitted with the text.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
||||
|
||||
|
||||
def build_scripted_messages(turns: list[ScriptedTurn]) -> list[BaseMessage]:
    """Translate :class:`ScriptedTurn` records into ``AIMessage`` payloads.

    Tool calls that omit an ``id`` receive a generated ``call_<hex>`` id so
    the fake model's output is always well-formed.
    """
    messages: list[BaseMessage] = []
    for turn in turns:
        calls = [
            {
                "name": tc["name"],
                "args": tc.get("args", {}),
                # Generate a stable-shape id when the script omitted one.
                "id": tc.get("id") or f"call_{uuid.uuid4().hex[:8]}",
            }
            for tc in turn.tool_calls
        ]
        messages.append(AIMessage(content=turn.text, tool_calls=calls))
    return messages
||||
|
||||
|
||||
@dataclass
class ScriptedHarness:
    """Pairs a scripted fake model with its realized stub tools.

    The two attributes plug directly into ``create_agent(model, tools=...)``.
    """

    # Fake chat model pre-loaded with the scripted responses.
    model: _ToolBindingFakeChatModel
    # Realized stub tools matching the scripted tool calls.
    tools: list[BaseTool]
||||
|
||||
|
||||
def build_scripted_harness(
    *,
    turns: list[ScriptedTurn],
    tools: list[StubToolSpec] | None = None,
    sleep: float | None = None,
) -> ScriptedHarness:
    """Assemble a deterministic agent harness from a scripted turn list.

    Args:
        turns: Ordered assistant turns the fake model will replay.
        tools: Stub tool specs to realize; ``None`` means no tools.
        sleep: Optional per-response delay forwarded to the fake model.

    Example::

        harness = build_scripted_harness(
            turns=[
                ScriptedTurn(tool_calls=[{"name": "echo", "args": {"x": 1}}]),
                ScriptedTurn(text="done"),
            ],
            tools=[
                StubToolSpec(name="echo", description="echo args", handler=lambda **kw: kw),
            ],
        )
    """
    scripted = build_scripted_messages(turns)
    fake_model = _ToolBindingFakeChatModel(responses=scripted, sleep=sleep)
    specs = tools if tools is not None else []
    realized = [spec.build() for spec in specs]
    return ScriptedHarness(model=fake_model, tools=realized)


__all__ = [
    "ScriptedHarness",
    "ScriptedTurn",
    "StubToolSpec",
    "build_scripted_harness",
    "build_scripted_messages",
]
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
"""Smoke test: scripted harness drives create_agent end-to-end and produces a tool-call-then-final-text trace."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from tests.integration.harness import (
|
||||
ScriptedTurn,
|
||||
StubToolSpec,
|
||||
build_scripted_harness,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scripted_harness_drives_basic_agent() -> None:
    """One tool-call turn followed by a final text turn round-trips cleanly."""
    harness = build_scripted_harness(
        turns=[
            ScriptedTurn(
                tool_calls=[{"name": "echo", "args": {"x": 1}, "id": "call_1"}]
            ),
            ScriptedTurn(text="done"),
        ],
        tools=[
            StubToolSpec(
                name="echo",
                description="Echo args back.",
                handler=lambda **kwargs: {"echoed": kwargs},
            ),
        ],
    )

    agent = create_agent(
        harness.model,
        system_prompt="You are a test agent.",
        tools=harness.tools,
    )

    result = await agent.ainvoke({"messages": [("user", "do the thing")]})
    transcript = result["messages"]

    # The last AIMessage must carry the scripted final text.
    final = next(
        (msg for msg in reversed(transcript) if msg.__class__.__name__ == "AIMessage"),
        None,
    )
    assert final is not None
    assert final.content == "done"

    # Exactly one ToolMessage, produced by the echo stub.
    tool_msgs = [msg for msg in transcript if msg.__class__.__name__ == "ToolMessage"]
    assert len(tool_msgs) == 1
    assert "echoed" in str(tool_msgs[0].content)
|
||||
1
surfsense_backend/tests/unit/agents/__init__.py
Normal file
1
surfsense_backend/tests/unit/agents/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
1
surfsense_backend/tests/unit/agents/new_chat/__init__.py
Normal file
1
surfsense_backend/tests/unit/agents/new_chat/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""__init__ stub so pytest discovers the prompts test module."""
|
||||
|
|
@ -0,0 +1,201 @@
|
|||
"""Tests for the prompt fragment composer (Tier 3a)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.prompts.composer import (
|
||||
ALL_TOOL_NAMES_ORDERED,
|
||||
compose_system_prompt,
|
||||
detect_provider_variant,
|
||||
)
|
||||
from app.db import ChatVisibility
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture
def fixed_today() -> datetime:
    """A frozen, timezone-aware 'today' so date assertions are stable."""
    return datetime(2025, 6, 1, 12, 0, tzinfo=UTC)
||||
|
||||
|
||||
class TestProviderVariantDetection:
    """``detect_provider_variant`` maps model-name strings to prompt variants."""

    @pytest.mark.parametrize(
        "model_name,expected",
        [
            # Classic OpenAI chat models
            ("openai:gpt-4o-mini", "openai_classic"),
            ("openai:gpt-4-turbo", "openai_classic"),
            # OpenAI reasoning-family models
            ("openai:gpt-5", "openai_reasoning"),
            ("openai:gpt-5-codex", "openai_reasoning"),
            ("openai:o1-preview", "openai_reasoning"),
            ("openai:o3-mini", "openai_reasoning"),
            # Anthropic, with either separator convention
            ("anthropic:claude-3-5-sonnet", "anthropic"),
            ("anthropic/claude-opus-4", "anthropic"),
            # Google-hosted Gemini under either prefix
            ("google:gemini-2.0-flash", "google"),
            ("vertex:gemini-1.5-pro", "google"),
            # Unknown, missing, or empty names fall back to default
            ("groq:mixtral-8x7b", "default"),
            (None, "default"),
            ("", "default"),
        ],
    )
    def test_detection(self, model_name: str | None, expected: str) -> None:
        assert detect_provider_variant(model_name) == expected
||||
|
||||
|
||||
class TestCompose:
    """End-to-end checks on ``compose_system_prompt`` output blocks.

    Each test composes a prompt under one configuration and asserts the
    presence/absence of the relevant XML-ish fragment tags or phrases.
    """

    def test_default_prompt_has_required_blocks(self, fixed_today: datetime) -> None:
        prompt = compose_system_prompt(today=fixed_today)
        # System instruction wrapper
        assert "<system_instruction>" in prompt
        assert "</system_instruction>" in prompt
        # Date interpolated
        assert "2025-06-01" in prompt
        # Core policy blocks present
        assert "<knowledge_base_only_policy>" in prompt
        assert "<tool_routing>" in prompt
        assert "<parameter_resolution>" in prompt
        assert "<memory_protocol>" in prompt
        # Tools
        assert "<tools>" in prompt
        assert "</tools>" in prompt
        # Citations on by default
        assert "<citation_instructions>" in prompt
        assert "[citation:chunk_id]" in prompt

    def test_team_visibility_uses_team_variants(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            thread_visibility=ChatVisibility.SEARCH_SPACE,
        )
        # Team-specific phrasing in the agent block
        assert "team space" in prompt
        # Memory protocol mentions team
        assert "team" in prompt
        # Should NOT mention the user-only memory phrasing
        assert "personal knowledge base" not in prompt

    def test_private_visibility_uses_private_variants(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            thread_visibility=ChatVisibility.PRIVATE,
        )
        assert "personal knowledge base" in prompt
        # Should NOT mention the team-specific phrasing about prefixed authors
        assert "[DisplayName of the author]" not in prompt

    def test_citations_disabled_swaps_block(self, fixed_today: datetime) -> None:
        # Compose both variants and check the swap is symmetric.
        prompt_on = compose_system_prompt(today=fixed_today, citations_enabled=True)
        prompt_off = compose_system_prompt(today=fixed_today, citations_enabled=False)
        assert "Citations are DISABLED" in prompt_off
        assert "Citations are DISABLED" not in prompt_on
        assert "[citation:chunk_id]" in prompt_on

    def test_enabled_tool_filter_only_includes_listed_tools(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"web_search", "scrape_webpage"},
        )
        assert "web_search:" in prompt or "- web_search:" in prompt
        assert "scrape_webpage:" in prompt or "- scrape_webpage:" in prompt
        # Excluded tools should NOT appear in tool listing
        assert "generate_podcast:" not in prompt
        assert "generate_image:" not in prompt

    def test_disabled_tool_note_is_appended(self, fixed_today: datetime) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"web_search"},
            disabled_tool_names={"generate_image", "generate_podcast"},
        )
        assert "DISABLED TOOLS (by user):" in prompt
        assert "Generate Image" in prompt
        assert "Generate Podcast" in prompt

    def test_mcp_routing_block_emits_when_provided(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]},
        )
        assert "<mcp_tool_routing>" in prompt
        assert "My GitLab" in prompt
        assert "gitlab_search" in prompt

    def test_mcp_routing_block_absent_when_no_servers(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={})
        assert "<mcp_tool_routing>" not in prompt

    def test_provider_block_renders_when_anthropic(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today, model_name="anthropic:claude-3-5-sonnet"
        )
        assert "<provider_hints>" in prompt
        assert "Anthropic" in prompt or "Claude" in prompt

    def test_provider_block_absent_for_default(self, fixed_today: datetime) -> None:
        prompt = compose_system_prompt(today=fixed_today, model_name="custom:foo")
        assert "<provider_hints>" not in prompt

    def test_custom_system_instructions_override_default(
        self, fixed_today: datetime
    ) -> None:
        # The {resolved_today} placeholder must be interpolated into the
        # custom text in place of the default instructions.
        custom = "You are a custom assistant. Today is {resolved_today}."
        prompt = compose_system_prompt(
            today=fixed_today, custom_system_instructions=custom
        )
        assert "You are a custom assistant. Today is 2025-06-01." in prompt
        # Default block should NOT be present
        assert "<knowledge_base_only_policy>" not in prompt

    def test_use_default_false_with_no_custom_yields_no_system_block(
        self, fixed_today: datetime
    ) -> None:
        prompt = compose_system_prompt(
            today=fixed_today,
            use_default_system_instructions=False,
        )
        # No system_instruction wrapper but tools/citations still emitted
        assert "<system_instruction>" not in prompt
        assert "<tools>" in prompt

    def test_all_known_tools_have_fragments(self) -> None:
        # Soft assertion: verify that every tool in the canonical order
        # produces non-empty content for at least one variant.
        for tool in ALL_TOOL_NAMES_ORDERED:
            prompt = compose_system_prompt(
                today=datetime(2025, 1, 1, tzinfo=UTC),
                enabled_tool_names={tool},
            )
            assert tool in prompt, f"tool {tool!r} missing from composed prompt"
||||
|
||||
|
||||
class TestStableOrderingForCacheStability:
    """Regression guard: prompt cache hit-rate depends on byte-stable prefix."""

    def test_composition_is_deterministic_given_same_inputs(
        self, fixed_today: datetime
    ) -> None:
        first = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"web_search", "scrape_webpage"},
            mcp_connector_tools={"X": ["x_a", "x_b"]},
        )
        # Same logical inputs built in a different set order — the composer
        # must sort internally so the output is byte-identical.
        second = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"scrape_webpage", "web_search"},
            mcp_connector_tools={"X": ["x_a", "x_b"]},
        )
        assert first == second
|
||||
311
surfsense_backend/tests/unit/agents/new_chat/test_action_log.py
Normal file
311
surfsense_backend/tests/unit/agents/new_chat/test_action_log.py
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
"""Unit tests for ActionLogMiddleware (Tier 5.2)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
|
||||
from app.agents.new_chat.tools.registry import ToolDefinition
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeRequest:
|
||||
"""Minimal stand-in for ToolCallRequest used in unit tests."""
|
||||
|
||||
tool_call: dict[str, Any]
|
||||
tool: Any = None
|
||||
state: Any = None
|
||||
runtime: Any = None
|
||||
|
||||
|
||||
@tool
def make_widget(color: str, size: int) -> str:
    """Create a widget."""
    # Deterministic echo of the inputs so assertions stay trivial.
    return f"made {color} {size}"
||||
|
||||
|
||||
def _enabled_flags(**overrides: bool) -> AgentFeatureFlags:
    """Build flags with the action log ON; keyword overrides win.

    Merging into a dict first (instead of passing literal keywords plus
    ``**overrides`` in one call) lets a caller override
    ``enable_action_log`` or ``disable_new_agent_stack`` without hitting
    ``TypeError: got multiple values for keyword argument``.
    """
    params: dict[str, bool] = {
        "disable_new_agent_stack": False,
        "enable_action_log": True,
    }
    params.update(overrides)
    return AgentFeatureFlags(**params)
||||
|
||||
|
||||
def _disabled_flags() -> AgentFeatureFlags:
    """Flags with the new stack enabled but the action log switched OFF."""
    return AgentFeatureFlags(
        disable_new_agent_stack=False,
        enable_action_log=False,
    )
||||
|
||||
|
||||
@pytest.fixture
def patch_get_flags():
    """Factory fixture: returns a callable building a ``get_flags`` patch."""

    def _patch(flags: AgentFeatureFlags):
        # Patch at the middleware's import site so the instance under
        # test observes the supplied flags.
        return patch(
            "app.agents.new_chat.middleware.action_log.get_flags",
            return_value=flags,
        )

    return _patch
||||
|
||||
|
||||
@pytest.fixture
def fake_session_factory():
    """Recording replacement for ``shielded_async_session``.

    Yields ``(captured, factory)``: calling ``factory()`` produces an
    async-context session whose ``add``/``commit`` record into ``captured``.
    """
    captured: dict[str, list] = {"rows": []}

    class _FakeSession:
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            # Never suppress exceptions raised inside the context.
            return False

        def add(self, row):
            captured["rows"].append(row)

        async def commit(self):
            captured["committed"] = True

    def _factory():
        # A fresh session object per call, all writing into ``captured``.
        return _FakeSession()

    return captured, _factory
||||
|
||||
|
||||
class TestActionLogMiddlewareDisabled:
    """The middleware must be a transparent pass-through when inactive."""

    @pytest.mark.asyncio
    async def test_no_op_when_flag_off(self, patch_get_flags) -> None:
        middleware = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {"color": "red", "size": 1}, "id": "tc1"}
        )
        handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
        with patch_get_flags(_disabled_flags()):
            result = await middleware.awrap_tool_call(request, handler)
        # The wrapped handler ran exactly once and its result flowed through.
        handler.assert_awaited_once()
        assert isinstance(result, ToolMessage)

    @pytest.mark.asyncio
    async def test_no_op_when_thread_id_none(self, patch_get_flags) -> None:
        # Without a thread id there is nothing to attribute a row to.
        middleware = ActionLogMiddleware(thread_id=None, search_space_id=1, user_id=None)
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
        )
        handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
        with patch_get_flags(_enabled_flags()):
            result = await middleware.awrap_tool_call(request, handler)
        assert isinstance(result, ToolMessage)
||||
|
||||
|
||||
class TestActionLogMiddlewarePersistence:
    """Row-writing behavior when the middleware is fully enabled."""

    @pytest.mark.asyncio
    async def test_writes_row_on_success(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory
        mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
        request = _FakeRequest(
            tool_call={
                "name": "make_widget",
                "args": {"color": "red", "size": 3},
                "id": "tc-abc",
            },
        )
        result_msg = ToolMessage(
            content="ok", tool_call_id="tc-abc", id="msg-1"
        )
        handler = AsyncMock(return_value=result_msg)

        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ):
            result = await mw.awrap_tool_call(request, handler)

        # The handler's result is returned untouched ...
        assert result is result_msg
        # ... and exactly one fully-populated row was persisted.
        assert len(captured["rows"]) == 1
        row = captured["rows"][0]
        assert row.thread_id == 42
        assert row.search_space_id == 7
        assert row.user_id == "u1"
        assert row.tool_name == "make_widget"
        assert row.args == {"color": "red", "size": 3}
        assert row.result_id == "msg-1"
        assert row.error is None
        assert row.reverse_descriptor is None
        assert row.reversible is False

    @pytest.mark.asyncio
    async def test_writes_row_on_failure_and_reraises(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory
        mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {"color": "red"}, "id": "tc1"}
        )
        handler = AsyncMock(side_effect=ValueError("boom"))

        # The original exception must propagate AND a failure row be logged.
        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ), pytest.raises(ValueError, match="boom"):
            await mw.awrap_tool_call(request, handler)

        assert len(captured["rows"]) == 1
        row = captured["rows"][0]
        assert row.tool_name == "make_widget"
        assert row.error == {"type": "ValueError", "message": "boom"}
        assert row.result_id is None

    @pytest.mark.asyncio
    async def test_persistence_failure_does_not_break_tool_call(
        self, patch_get_flags
    ) -> None:
        """Even if the DB write blows up, the tool's result must reach the model."""
        mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
        )
        result_msg = ToolMessage(content="ok", tool_call_id="tc1")
        handler = AsyncMock(return_value=result_msg)

        def _exploding_session():
            # Simulates a completely unavailable database.
            raise RuntimeError("DB is down")

        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=_exploding_session
        ):
            result = await mw.awrap_tool_call(request, handler)
        assert result is result_msg
||||
|
||||
|
||||
class TestReverseDescriptor:
    """Reverse-descriptor rendering when a ToolDefinition declares ``reverse``."""

    @pytest.mark.asyncio
    async def test_renders_reverse_descriptor_when_tool_declares_one(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory

        def _reverse(args, result):
            # Descriptor of the inverse operation, built from the parsed
            # tool result.
            return {"tool": "delete_widget", "args": {"id": result["id"]}}

        tool_def = ToolDefinition(
            name="make_widget",
            description="Create a widget",
            factory=lambda deps: make_widget,
            reverse=_reverse,
        )
        mw = ActionLogMiddleware(
            thread_id=1,
            search_space_id=1,
            user_id="u",
            tool_definitions={"make_widget": tool_def},
        )
        request = _FakeRequest(
            tool_call={
                "name": "make_widget",
                "args": {"color": "blue", "size": 1},
                "id": "tc-xyz",
            },
        )
        # JSON content — the middleware parses it before calling ``reverse``.
        result_msg = ToolMessage(
            content='{"id": "widget-9"}', tool_call_id="tc-xyz", id="msg-9"
        )
        handler = AsyncMock(return_value=result_msg)

        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ):
            await mw.awrap_tool_call(request, handler)

        row = captured["rows"][0]
        assert row.reversible is True
        assert row.reverse_descriptor == {
            "tool": "delete_widget",
            "args": {"id": "widget-9"},
        }

    @pytest.mark.asyncio
    async def test_swallows_reverse_callable_errors(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory

        def _bad_reverse(args, result):
            raise RuntimeError("reverse blew up")

        tool_def = ToolDefinition(
            name="make_widget",
            description="Create a widget",
            factory=lambda deps: make_widget,
            reverse=_bad_reverse,
        )
        mw = ActionLogMiddleware(
            thread_id=1,
            search_space_id=1,
            user_id=None,
            tool_definitions={"make_widget": tool_def},
        )
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
        )
        result_msg = ToolMessage(content="ok", tool_call_id="tc1")
        handler = AsyncMock(return_value=result_msg)

        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ):
            await mw.awrap_tool_call(request, handler)

        # A broken reverse callable degrades to "not reversible", not a crash.
        row = captured["rows"][0]
        assert row.reversible is False
        assert row.reverse_descriptor is None

    @pytest.mark.asyncio
    async def test_no_reverse_when_tool_definition_missing(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory
        mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        request = _FakeRequest(
            tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"}
        )
        handler = AsyncMock(
            return_value=ToolMessage(content="ok", tool_call_id="tc1")
        )
        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ):
            await mw.awrap_tool_call(request, handler)
        row = captured["rows"][0]
        assert row.reversible is False
||||
|
||||
|
||||
class TestArgsTruncation:
    """Oversized args payloads must be replaced by a truncation marker."""

    @pytest.mark.asyncio
    async def test_huge_args_payload_is_truncated(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        captured, factory = fake_session_factory
        middleware = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        # Build a > 32KB string so the persisted payload triggers the truncation path.
        oversized_blob = "x" * (40 * 1024)
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {"blob": oversized_blob}, "id": "tc1"},
        )
        handler = AsyncMock(
            return_value=ToolMessage(content="ok", tool_call_id="tc1")
        )

        with patch_get_flags(_enabled_flags()), patch(
            "app.db.shielded_async_session", side_effect=lambda: factory()
        ):
            await middleware.awrap_tool_call(request, handler)

        persisted = captured["rows"][0]
        assert persisted.args is not None
        # The marker dict replaces the raw payload but records its size.
        assert persisted.args.get("_truncated") is True
        assert persisted.args.get("_size", 0) >= 40 * 1024
||||
|
|
@ -0,0 +1,90 @@
|
|||
"""Tests for BusyMutexMiddleware: per-thread lock + cancel event behavior."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.errors import BusyError
|
||||
from app.agents.new_chat.middleware.busy_mutex import (
|
||||
BusyMutexMiddleware,
|
||||
get_cancel_event,
|
||||
manager,
|
||||
request_cancel,
|
||||
reset_cancel,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _Runtime:
|
||||
def __init__(self, thread_id: str | None) -> None:
|
||||
self.config = {"configurable": {"thread_id": thread_id}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_first_acquire_succeeds_and_release_unblocks() -> None:
    """A fresh thread id acquires immediately; ``aafter_agent`` releases."""
    middleware = BusyMutexMiddleware()
    runtime = _Runtime("t1")
    await middleware.abefore_agent({}, runtime)

    # The turn is in flight, so the per-thread lock must be held.
    per_thread_lock = manager.lock_for("t1")
    assert per_thread_lock.locked()

    await middleware.aafter_agent({}, runtime)
    assert not per_thread_lock.locked()
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_second_concurrent_acquire_raises_busy() -> None:
    """While one turn holds the lock, a second caller gets ``BusyError``."""
    holder = BusyMutexMiddleware()
    contender = BusyMutexMiddleware()
    runtime = _Runtime("t-conflict")
    await holder.abefore_agent({}, runtime)

    with pytest.raises(BusyError) as excinfo:
        await contender.abefore_agent({}, runtime)
    # The error identifies which thread was busy.
    assert excinfo.value.request_id == "t-conflict"

    # Releasing lets the previously rejected caller proceed normally.
    await holder.aafter_agent({}, runtime)
    await contender.abefore_agent({}, runtime)
    await contender.aafter_agent({}, runtime)
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancel_event_lifecycle() -> None:
    """Cancel event starts cleared, sets on request, resets at turn end."""
    middleware = BusyMutexMiddleware()
    runtime = _Runtime("t-cancel")

    await middleware.abefore_agent({}, runtime)
    event = get_cancel_event("t-cancel")
    assert not event.is_set()

    request_cancel("t-cancel")
    assert event.is_set()

    # End of turn must clear any pending cancellation.
    await middleware.aafter_agent({}, runtime)
    assert not event.is_set()
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_thread_id_raises_when_required() -> None:
    """``require_thread_id=True`` treats a missing thread id as an error."""
    middleware = BusyMutexMiddleware(require_thread_id=True)
    runtime = _Runtime(None)
    with pytest.raises(BusyError):
        await middleware.abefore_agent({}, runtime)
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_thread_id_skipped_when_not_required() -> None:
    """Without a thread id and ``require_thread_id=False``, hooks are no-ops."""
    middleware = BusyMutexMiddleware(require_thread_id=False)
    runtime = _Runtime(None)
    # Neither call should raise.
    await middleware.abefore_agent({}, runtime)
    await middleware.aafter_agent({}, runtime)
||||
|
||||
|
||||
def test_reset_cancel_idempotent() -> None:
    """Resetting an unknown thread id must be a safe no-op."""
    # Should not raise even if the event was never created.
    reset_cancel("never-seen")
||||
107
surfsense_backend/tests/unit/agents/new_chat/test_compaction.py
Normal file
107
surfsense_backend/tests/unit/agents/new_chat/test_compaction.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""Tests for SurfSenseCompactionMiddleware: protected SystemMessage handling and content sanitization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
from app.agents.new_chat.middleware.compaction import (
|
||||
PROTECTED_SYSTEM_PREFIXES,
|
||||
_is_protected_system_message,
|
||||
_sanitize_message_content,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestIsProtectedSystemMessage:
    """Protection is keyed off known prefixes on ``SystemMessage`` content."""

    @pytest.mark.parametrize("prefix", PROTECTED_SYSTEM_PREFIXES)
    def test_each_prefix_protected(self, prefix: str) -> None:
        # Any SystemMessage whose content starts with a protected prefix
        # must be flagged, regardless of what follows.
        message = SystemMessage(content=f"{prefix}\nbody\n</close>")
        assert _is_protected_system_message(message) is True

    def test_unprotected_system_message(self) -> None:
        assert _is_protected_system_message(SystemMessage(content="random instructions")) is False

    def test_human_message_never_protected(self) -> None:
        # Only SystemMessage instances qualify, whatever their text looks like.
        assert _is_protected_system_message(HumanMessage(content="<workspace_tree>...")) is False

    def test_tolerates_leading_whitespace(self) -> None:
        message = SystemMessage(content="  \n<priority_documents>\n...")
        assert _is_protected_system_message(message) is True
||||
|
||||
|
||||
class TestSanitizeMessageContent:
    """Sanitizer contract: identity for valid content, empty string for None."""

    def test_returns_same_message_when_content_present(self) -> None:
        message = AIMessage(content="hello")
        # Identity check, not just equality — no needless copying.
        assert _sanitize_message_content(message) is message

    def test_replaces_none_with_empty_string(self) -> None:
        # Pydantic blocks ``content=None`` at construction; the real
        # crash happens when the streaming layer mutates ``content``
        # after-the-fact. Replicate that by force-setting on a built
        # message.
        message = AIMessage(
            content="",
            tool_calls=[{"name": "x", "args": {}, "id": "1"}],
        )
        # Bypass pydantic validation to simulate the LiteLLM/Bedrock case
        object.__setattr__(message, "content", None)
        sanitized = _sanitize_message_content(message)
        assert sanitized.content == ""
||||
|
||||
|
||||
class TestPartitionMessages:
    """Verify the partition override surfaces protected SystemMessages
    into ``preserved_messages`` regardless of cutoff position.
    """

    def _build_partitioner(self):
        # Construct a thin shim — we can't easily instantiate the full
        # SurfSenseCompactionMiddleware without a real model, but the
        # override path needs ``_lc_helper`` to delegate to. We mock
        # that with a simple slicing partitioner equivalent to the real one.
        from app.agents.new_chat.middleware.compaction import (
            SurfSenseCompactionMiddleware,
        )

        class _LcHelper:
            @staticmethod
            def _partition_messages(messages, cutoff):
                # Plain slice: everything before cutoff gets summarized.
                return messages[:cutoff], messages[cutoff:]

        class _Stub(SurfSenseCompactionMiddleware):
            # Deliberately skip super().__init__ — only ``_lc_helper`` is
            # needed for the partition path under test.
            def __init__(self):
                self._lc_helper = _LcHelper()

        return _Stub()

    def test_protected_system_message_preserved_even_in_summarize_half(self) -> None:
        partitioner = self._build_partitioner()
        protected = SystemMessage(content="<priority_documents>\n...")
        msgs = [
            HumanMessage(content="old human"),
            AIMessage(content="old ai"),
            protected,
            ToolMessage(content="tool 1", tool_call_id="t1"),
            HumanMessage(content="new"),
        ]
        # Cutoff = 4 means everything before index 4 should be summarized
        to_summary, preserved = partitioner._partition_messages(msgs, 4)

        # The protected message is rescued from the summarize half.
        assert protected not in to_summary
        assert protected in preserved
        # The non-protected old messages remain in to_summary
        assert any(isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary)

    def test_unprotected_messages_unaffected(self) -> None:
        partitioner = self._build_partitioner()
        msgs = [HumanMessage(content="a"), HumanMessage(content="b"), HumanMessage(content="c")]
        to_summary, preserved = partitioner._partition_messages(msgs, 2)
        assert [m.content for m in to_summary] == ["a", "b"]
        assert [m.content for m in preserved] == ["c"]
||||
|
|
@ -0,0 +1,107 @@
|
|||
"""Tests for SpillToBackendEdit and SpillingContextEditingMiddleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from app.agents.new_chat.middleware.context_editing import (
|
||||
SpillToBackendEdit,
|
||||
_build_spill_placeholder,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _build_history(num_pairs: int = 6) -> list[Any]:
    """Build a long history of (AIMessage with tool_call, ToolMessage) pairs."""
    history: list[Any] = [HumanMessage(content="please do many things")]
    for idx in range(num_pairs):
        # Each pair shares a call id so the ToolMessage answers its AIMessage.
        call = {"name": f"tool_{idx}", "args": {"i": idx}, "id": f"call-{idx}"}
        history.append(AIMessage(content="", tool_calls=[call]))
        history.append(
            ToolMessage(
                # Deliberately oversized payload so token thresholds trip.
                content="x" * 5000,
                tool_call_id=f"call-{idx}",
                name=f"tool_{idx}",
                id=f"tool-msg-{idx}",
            )
        )
    return history
|
||||
|
||||
|
||||
def _approx_count(messages: list[Any]) -> int:
|
||||
"""Trivial token counter: 1 token per 4 chars."""
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = getattr(msg, "content", "")
|
||||
if isinstance(content, str):
|
||||
total += len(content) // 4
|
||||
return total
|
||||
|
||||
|
||||
class TestSpillEdit:
    """Exercise SpillToBackendEdit trigger/keep/exclude/drain semantics."""

    def test_below_trigger_does_nothing(self) -> None:
        # Trigger far above the history size: apply() must be a no-op.
        edit = SpillToBackendEdit(trigger=1_000_000, keep=2)
        history = _build_history(3)
        before = [len(getattr(m, "content", "")) for m in history]
        edit.apply(history, count_tokens=_approx_count)
        after = [len(getattr(m, "content", "")) for m in history]
        assert before == after
        assert edit.pending_spills == []

    def test_above_trigger_clears_and_records(self) -> None:
        edit = SpillToBackendEdit(trigger=100, keep=1, path_prefix="/tool_outputs")
        history = _build_history(4)
        edit.apply(history, count_tokens=_approx_count)

        tool_messages = [m for m in history if isinstance(m, ToolMessage)]
        # The most-recent ToolMessage (keep=1) should remain intact
        assert tool_messages[-1].content.startswith("x")  # untouched

        # Earlier ToolMessages should now contain the placeholder text
        cleared = [
            m
            for m in tool_messages
            if isinstance(m.content, str) and m.content.startswith("[cleared")
        ]
        assert len(cleared) >= 1
        # And the spill list should match
        assert len(edit.pending_spills) == len(cleared)

    def test_excluded_tools_not_cleared(self) -> None:
        edit = SpillToBackendEdit(
            trigger=100,
            keep=0,
            exclude_tools=("tool_0",),
        )
        history = _build_history(4)
        edit.apply(history, count_tokens=_approx_count)

        tool_zero = next(
            m for m in history if isinstance(m, ToolMessage) and m.name == "tool_0"
        )
        # Excluded — untouched
        assert tool_zero.content.startswith("x")

    def test_drain_clears_pending(self) -> None:
        edit = SpillToBackendEdit(trigger=100, keep=1)
        history = _build_history(4)
        edit.apply(history, count_tokens=_approx_count)
        # First drain yields everything; the second must come back empty.
        assert len(edit.drain_pending()) > 0
        assert edit.drain_pending() == []

    def test_placeholder_format(self) -> None:
        path = "/tool_outputs/thread-1/tool-msg-0.txt"
        placeholder = _build_spill_placeholder(path)
        assert path in placeholder
        assert "explore" in placeholder  # mentions the recovery agent
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
"""Tests for declarative dedup_key on ToolDefinition (Tier 2.3 migration)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_tool(name: str, *, dedup_key=None, hitl_dedup_key=None):
    """Build a StructuredTool carrying optional dedup metadata.

    ``dedup_key`` attaches a callable resolver; ``hitl_dedup_key``
    additionally flags the tool as HITL with a string arg-name key.
    """
    metadata: dict = {}
    if dedup_key is not None:
        metadata["dedup_key"] = dedup_key
    if hitl_dedup_key is not None:
        metadata["hitl"] = True
        metadata["hitl_dedup_key"] = hitl_dedup_key

    def _fn(**kwargs):
        # Deterministic stub body; the tests never invoke the tool.
        return "ok"

    return StructuredTool.from_function(
        func=_fn, name=name, description="x", metadata=metadata
    )
|
||||
|
||||
|
||||
def _msg(*calls: dict) -> AIMessage:
    """Wrap the given tool-call dicts in an empty-content AIMessage."""
    return AIMessage(content="", tool_calls=[*calls])
|
||||
|
||||
|
||||
class _Runtime:
    # Minimal stand-in for the runtime object ``after_model`` receives;
    # the dedup middleware never touches it in these tests.
    pass
|
||||
|
||||
|
||||
def test_callable_dedup_key_takes_priority() -> None:
    """A callable ``dedup_key`` collapses calls that map to the same key."""
    tool = _make_tool(
        "create_doc",
        dedup_key=lambda args: f"{args.get('parent_id')}::{args.get('title')}",
    )
    mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
    message = _msg(
        {"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "1"},
        {"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "2"},
        {"name": "create_doc", "args": {"parent_id": "x", "title": "z"}, "id": "3"},
    )
    out = mw.after_model({"messages": [message]}, _Runtime())
    assert out is not None
    surviving = out["messages"][0].tool_calls
    assert len(surviving) == 2  # one duplicate dropped
    assert {call["id"] for call in surviving} == {"1", "3"}
|
||||
|
||||
|
||||
def test_string_hitl_dedup_key_still_works() -> None:
    """String-valued ``hitl_dedup_key`` dedups case-insensitively on one arg."""
    tool = _make_tool("send_x", hitl_dedup_key="subject")
    mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
    message = _msg(
        {"name": "send_x", "args": {"subject": "Hello"}, "id": "1"},
        {"name": "send_x", "args": {"subject": "hello"}, "id": "2"},  # case
    )
    out = mw.after_model({"messages": [message]}, _Runtime())
    assert out is not None
    assert len(out["messages"][0].tool_calls) == 1
|
||||
|
||||
|
||||
def test_no_agent_tools_means_no_dedup() -> None:
    """After the cleanup tier removed the legacy ``_NATIVE_HITL_TOOL_DEDUP_KEYS``
    map, dedup is purely declarative — no resolvers means no dedup runs.

    Coverage for the previously hardcoded native HITL tools now lives on
    each :class:`ToolDefinition.dedup_key` in
    :mod:`app.agents.new_chat.tools.registry`, which is wired through to
    ``tool.metadata`` by :func:`build_tools`.
    """
    mw = DedupHITLToolCallsMiddleware(agent_tools=None)
    duplicated = _msg(
        {"name": "create_notion_page", "args": {"title": "X"}, "id": "1"},
        {"name": "create_notion_page", "args": {"title": "x"}, "id": "2"},
    )
    # With no resolvers wired in, after_model must not rewrite anything.
    assert mw.after_model({"messages": [duplicated]}, _Runtime()) is None
|
||||
|
||||
|
||||
def test_registry_propagates_dedup_key_to_tool_metadata() -> None:
    """Smoke-check the wiring path that replaced the legacy native map.

    ``ToolDefinition.dedup_key`` set in the registry must be copied onto
    the constructed tool's ``metadata`` so :class:`DedupHITLToolCallsMiddleware`
    can pick it up at agent build time.
    """
    from app.agents.new_chat.tools.registry import (
        BUILTIN_TOOLS,
        wrap_dedup_key_by_arg_name,
    )

    matching = [t for t in BUILTIN_TOOLS if t.name == "create_notion_page"]
    assert matching, "registry should still expose create_notion_page"
    assert matching[0].dedup_key is not None
    # Same wrapping helper used in the registry — sanity check identity
    resolver = wrap_dedup_key_by_arg_name("title")
    assert resolver({"title": "Plan"}) == "plan"
|
||||
|
||||
|
||||
def test_unknown_tool_passes_through() -> None:
    """Tools with no dedup configuration keep all their calls."""
    mw = DedupHITLToolCallsMiddleware(agent_tools=None)
    repeated = _msg(
        {"name": "anything_else", "args": {"x": 1}, "id": "1"},
        {"name": "anything_else", "args": {"x": 1}, "id": "2"},
    )
    out = mw.after_model({"messages": [repeated]}, _Runtime())
    assert out is None  # no dedup configured -> kept
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
"""Lock in the default-allow layering used by ``chat_deepagent``.
|
||||
|
||||
The agent factory wires ``PermissionMiddleware`` with three rulesets,
|
||||
earliest -> latest:
|
||||
|
||||
1. ``surfsense_defaults`` (single ``allow */*`` rule)
|
||||
2. ``connector_synthesized`` (deny rules for tools whose required
|
||||
connector is missing)
|
||||
3. (future) user-defined rules from the Agent Permissions UI
|
||||
|
||||
Without #1 every read-only built-in (``ls``, ``read_file``, ``grep``,
|
||||
``glob``, ``web_search`` …) defaulted to ``ask`` because
|
||||
``permissions.evaluate`` returns ``ask`` when no rule matches. That
|
||||
caused two production-painful behaviors:
|
||||
|
||||
* Resume payloads with a prior reject decision bled into innocent
|
||||
read-only tool calls, raising ``RejectedError("ls")``.
|
||||
* Mutating connector tools got *double* prompted — once via the
|
||||
middleware ``ask`` and again via the per-tool ``interrupt()`` in
|
||||
``app.agents.new_chat.tools.hitl``.
|
||||
|
||||
These tests pin the layering so a refactor that drops the default
|
||||
ruleset fails loud.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.permissions import (
|
||||
Rule,
|
||||
Ruleset,
|
||||
aggregate_action,
|
||||
evaluate_many,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _layered_rulesets(connector_denies: list[Rule]) -> list[Ruleset]:
    """Replicate ``chat_deepagent`` layering for the test."""
    # Layer 1: blanket allow, mirroring the production default ruleset.
    defaults = Ruleset(
        rules=[Rule(permission="*", pattern="*", action="allow")],
        origin="surfsense_defaults",
    )
    # Layer 2: denies synthesized for tools with a missing connector.
    synthesized = Ruleset(rules=connector_denies, origin="connector_synthesized")
    return [defaults, synthesized]
|
||||
|
||||
|
||||
class TestReadOnlyToolsAllowed:
    """Read-only built-ins must NOT default to ask."""

    @pytest.mark.parametrize(
        "tool_name",
        [
            "ls",
            "read_file",
            "grep",
            "glob",
            "web_search",
            "scrape_webpage",
            "search_surfsense_docs",
            "get_connected_accounts",
            "write_todos",
            "task",
            "_noop",
            "invalid",
            "update_memory",
        ],
    )
    def test_default_allow_covers_safe_builtin(self, tool_name: str) -> None:
        # With no connector denies, the blanket default must resolve to allow.
        layered = _layered_rulesets(connector_denies=[])
        matched = evaluate_many(tool_name, [tool_name], *layered)
        assert aggregate_action(matched) == "allow"
|
||||
|
||||
|
||||
class TestConnectorDenyOverridesDefaultAllow:
    """Connector-synthesized denies must beat the default-allow rule."""

    def test_missing_connector_tool_is_denied(self) -> None:
        deny = Rule(permission="linear_create_issue", pattern="*", action="deny")
        layered = _layered_rulesets(connector_denies=[deny])
        matched = evaluate_many(
            "linear_create_issue", ["linear_create_issue"], *layered
        )
        assert aggregate_action(matched) == "deny"

    def test_default_allow_still_applies_to_other_tools(self) -> None:
        """A deny rule for one tool must not bleed onto unrelated calls."""
        deny = Rule(permission="linear_create_issue", pattern="*", action="deny")
        layered = _layered_rulesets(connector_denies=[deny])
        matched = evaluate_many("ls", ["ls"], *layered)
        assert aggregate_action(matched) == "allow"
|
||||
|
||||
|
||||
class TestUserRuleOverridesDefault:
    """User rules layered last must override the default-allow rule."""

    @staticmethod
    def _default_allow() -> Ruleset:
        # Shared layer 1: the production-style blanket allow.
        return Ruleset(
            rules=[Rule(permission="*", pattern="*", action="allow")],
            origin="surfsense_defaults",
        )

    def test_user_ask_overrides_default_allow(self) -> None:
        user_ruleset = Ruleset(
            rules=[Rule(permission="ls", pattern="*", action="ask")],
            origin="user",
        )
        matched = evaluate_many("ls", ["ls"], self._default_allow(), user_ruleset)
        assert aggregate_action(matched) == "ask"

    def test_user_deny_overrides_default_allow(self) -> None:
        user_ruleset = Ruleset(
            rules=[Rule(permission="send_*", pattern="*", action="deny")],
            origin="user",
        )
        matched = evaluate_many(
            "send_gmail_email",
            ["send_gmail_email"],
            self._default_allow(),
            user_ruleset,
        )
        assert aggregate_action(matched) == "deny"
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
"""Tests for DoomLoopMiddleware signature equality detection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware, _signature
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_signature_is_stable_for_identical_args() -> None:
    # Key ordering inside the args dict must not change the signature.
    left = _signature("search", {"q": "hello", "n": 10})
    right = _signature("search", {"n": 10, "q": "hello"})
    assert left == right
|
||||
|
||||
|
||||
def test_signature_changes_with_args() -> None:
    # Different argument values must yield different signatures.
    assert _signature("search", {"q": "hello"}) != _signature("search", {"q": "world"})
|
||||
|
||||
|
||||
def test_signature_changes_with_name() -> None:
    # The tool name participates in the signature, not just the args.
    assert _signature("search", {"q": "x"}) != _signature("read", {"q": "x"})
|
||||
|
||||
|
||||
class _FakeRuntime:
|
||||
def __init__(self, thread_id: str | None = "thread-1") -> None:
|
||||
self.config = {"configurable": {"thread_id": thread_id}}
|
||||
|
||||
|
||||
def _msg_calling(name: str, args: dict, call_id: str) -> AIMessage:
    """Build an empty-content AIMessage issuing exactly one tool call."""
    single_call = {"name": name, "args": args, "id": call_id}
    return AIMessage(content="", tool_calls=[single_call])
|
||||
|
||||
|
||||
def test_threshold_triggers_after_n_identical_calls() -> None:
    mw = DoomLoopMiddleware(threshold=3)
    runtime = _FakeRuntime()

    # First two calls — under threshold
    for attempt in range(2):
        state = {"messages": [_msg_calling("repeat", {"x": 1}, f"call-{attempt}")]}
        assert mw.after_model(state, runtime) is None

    # Third identical call should trigger ``langgraph.types.interrupt``.
    # In a unit-test context (no runnable graph), ``interrupt`` raises
    # ``RuntimeError`` because ``get_config`` has nothing to bind to —
    # we accept that as proof the interrupt path was taken (the
    # alternative would be no exception, which would mean the loop
    # detection never fired).
    with pytest.raises(Exception) as excinfo:
        mw.after_model(
            {"messages": [_msg_calling("repeat", {"x": 1}, "call-3")]},
            runtime,
        )
    exc_name = type(excinfo.value).__name__.lower()
    assert "interrupt" in exc_name or "runtimeerror" in exc_name, (
        f"Expected an interrupt-style exception, got {exc_name}"
    )
|
||||
|
||||
|
||||
def test_does_not_trigger_when_args_differ() -> None:
    """Consecutive calls with different args never count as a loop."""
    mw = DoomLoopMiddleware(threshold=2)
    runtime = _FakeRuntime()
    for call_id, x_value in (("1", 1), ("2", 2)):
        state = {"messages": [_msg_calling("repeat", {"x": x_value}, call_id)]}
        assert mw.after_model(state, runtime) is None
|
||||
|
||||
|
||||
def test_separate_threads_have_independent_windows() -> None:
    """Loop counters are scoped per thread_id, not shared globally."""
    mw = DoomLoopMiddleware(threshold=2)
    runtime_a = _FakeRuntime(thread_id="A")
    runtime_b = _FakeRuntime(thread_id="B")

    mw.after_model({"messages": [_msg_calling("foo", {}, "1")]}, runtime_a)
    # thread B should NOT count thread A's call
    result_b = mw.after_model({"messages": [_msg_calling("foo", {}, "1")]}, runtime_b)
    assert result_b is None  # not yet at threshold for B
|
||||
|
||||
|
||||
def test_invalid_threshold_rejected() -> None:
    # A threshold of 1 would interrupt on the first call; disallowed.
    with pytest.raises(ValueError):
        DoomLoopMiddleware(threshold=1)
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
"""Tests for the agent feature-flag system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.feature_flags import (
|
||||
AgentFeatureFlags,
|
||||
reload_for_tests,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
    """Remove every agent feature-flag env var so defaults apply."""
    flag_env_vars = (
        "SURFSENSE_DISABLE_NEW_AGENT_STACK",
        "SURFSENSE_ENABLE_CONTEXT_EDITING",
        "SURFSENSE_ENABLE_COMPACTION_V2",
        "SURFSENSE_ENABLE_RETRY_AFTER",
        "SURFSENSE_ENABLE_MODEL_FALLBACK",
        "SURFSENSE_ENABLE_MODEL_CALL_LIMIT",
        "SURFSENSE_ENABLE_TOOL_CALL_LIMIT",
        "SURFSENSE_ENABLE_TOOL_CALL_REPAIR",
        "SURFSENSE_ENABLE_DOOM_LOOP",
        "SURFSENSE_ENABLE_PERMISSION",
        "SURFSENSE_ENABLE_BUSY_MUTEX",
        "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
        "SURFSENSE_ENABLE_SKILLS",
        "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
        "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
        "SURFSENSE_ENABLE_ACTION_LOG",
        "SURFSENSE_ENABLE_REVERT_ROUTE",
        "SURFSENSE_ENABLE_PLUGIN_LOADER",
        "SURFSENSE_ENABLE_OTEL",
    )
    for env_var in flag_env_vars:
        # raising=False: absent vars are fine — we only care they end up unset.
        monkeypatch.delenv(env_var, raising=False)
|
||||
|
||||
|
||||
def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no env vars set, every flag defaults to off."""
    _clear_all(monkeypatch)
    flags = reload_for_tests()
    assert isinstance(flags, AgentFeatureFlags)
    assert flags.disable_new_agent_stack is False
    assert flags.any_new_middleware_enabled() is False
|
||||
|
||||
|
||||
def test_master_kill_switch_overrides_individual_flags(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The master kill switch forces every individual flag off."""
    _clear_all(monkeypatch)
    monkeypatch.setenv("SURFSENSE_DISABLE_NEW_AGENT_STACK", "true")
    monkeypatch.setenv("SURFSENSE_ENABLE_CONTEXT_EDITING", "true")
    monkeypatch.setenv("SURFSENSE_ENABLE_PERMISSION", "true")

    flags = reload_for_tests()
    assert flags.disable_new_agent_stack is True
    # The individually-enabled flags must read back as off.
    assert flags.enable_context_editing is False
    assert flags.enable_permission is False
    assert flags.any_new_middleware_enabled() is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("truthy", ["1", "true", "TRUE", "yes", "on"])
def test_individual_flags_truthy_values(
    monkeypatch: pytest.MonkeyPatch, truthy: str
) -> None:
    """Every accepted truthy spelling flips a flag on."""
    _clear_all(monkeypatch)
    monkeypatch.setenv("SURFSENSE_ENABLE_RETRY_AFTER", truthy)
    flags = reload_for_tests()
    assert flags.enable_retry_after is True
    assert flags.any_new_middleware_enabled() is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("falsy", ["0", "false", "no", "off", "", "garbage"])
def test_individual_flags_falsy_values(
    monkeypatch: pytest.MonkeyPatch, falsy: str
) -> None:
    """Unrecognized or falsy spellings (including garbage) keep a flag off."""
    _clear_all(monkeypatch)
    monkeypatch.setenv("SURFSENSE_ENABLE_RETRY_AFTER", falsy)
    flags = reload_for_tests()
    assert flags.enable_retry_after is False
|
||||
|
||||
|
||||
def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) -> None:
    """Each env var controls exactly one attribute on the flags object."""
    _clear_all(monkeypatch)
    flag_to_env = {
        "enable_context_editing": "SURFSENSE_ENABLE_CONTEXT_EDITING",
        "enable_compaction_v2": "SURFSENSE_ENABLE_COMPACTION_V2",
        "enable_retry_after": "SURFSENSE_ENABLE_RETRY_AFTER",
        "enable_model_fallback": "SURFSENSE_ENABLE_MODEL_FALLBACK",
        "enable_model_call_limit": "SURFSENSE_ENABLE_MODEL_CALL_LIMIT",
        "enable_tool_call_limit": "SURFSENSE_ENABLE_TOOL_CALL_LIMIT",
        "enable_tool_call_repair": "SURFSENSE_ENABLE_TOOL_CALL_REPAIR",
        "enable_doom_loop": "SURFSENSE_ENABLE_DOOM_LOOP",
        "enable_permission": "SURFSENSE_ENABLE_PERMISSION",
        "enable_busy_mutex": "SURFSENSE_ENABLE_BUSY_MUTEX",
        "enable_llm_tool_selector": "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
        "enable_skills": "SURFSENSE_ENABLE_SKILLS",
        "enable_specialized_subagents": "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
        "enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
        "enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
        "enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
        "enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
        "enable_otel": "SURFSENSE_ENABLE_OTEL",
    }

    # `enable_otel` is intentionally orthogonal — it does NOT count toward
    # ``any_new_middleware_enabled`` because OTel is observability-only and
    # ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
    counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}

    for attr, env_name in flag_to_env.items():
        _clear_all(monkeypatch)
        monkeypatch.setenv(env_name, "true")
        flags = reload_for_tests()
        assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}"
        # Exactly the middleware-relevant flags should trip the aggregate.
        expected_aggregate = attr in counts_toward_middleware
        assert flags.any_new_middleware_enabled() is expected_aggregate
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
"""Tests for NoopInjectionMiddleware provider-compat logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agents.new_chat.middleware.noop_injection import (
|
||||
NOOP_TOOL_NAME,
|
||||
NoopInjectionMiddleware,
|
||||
_last_ai_has_tool_calls,
|
||||
_provider_needs_noop,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _LiteLLMModel:
    # Model stub whose langsmith params report the litellm provider.
    def _get_ls_params(self):
        return {"ls_provider": "litellm"}
|
||||
|
||||
|
||||
class _BedrockModel:
    # Model stub whose langsmith params report the bedrock provider.
    def _get_ls_params(self):
        return {"ls_provider": "bedrock"}
|
||||
|
||||
|
||||
class _OpenAIModel:
    # Model stub whose langsmith params report the openai provider.
    def _get_ls_params(self):
        return {"ls_provider": "openai"}
|
||||
|
||||
|
||||
class _ChatLiteLLM:  # name-only fallback
    # No ``_get_ls_params``: exercises the class-name-based detection path.
    pass
|
||||
|
||||
|
||||
class TestProviderDetection:
    """Which providers require the _noop tool injection."""

    def test_litellm(self) -> None:
        assert _provider_needs_noop(_LiteLLMModel()) is True

    def test_bedrock(self) -> None:
        assert _provider_needs_noop(_BedrockModel()) is True

    def test_openai_does_not_need(self) -> None:
        assert _provider_needs_noop(_OpenAIModel()) is False

    def test_class_name_fallback(self) -> None:
        # No ls params available — detection falls back to the class name.
        assert _provider_needs_noop(_ChatLiteLLM()) is True
|
||||
|
||||
|
||||
class TestHistoryDetection:
    """Detection of a trailing AIMessage that carries tool calls."""

    def test_last_ai_has_tool_calls(self) -> None:
        history = [
            HumanMessage(content="hi"),
            AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]),
        ]
        assert _last_ai_has_tool_calls(history) is True

    def test_last_ai_no_tool_calls(self) -> None:
        history = [
            HumanMessage(content="hi"),
            AIMessage(content="hello"),
        ]
        assert _last_ai_has_tool_calls(history) is False

    def test_no_ai_in_history(self) -> None:
        # Human-only history can never have pending tool calls.
        assert _last_ai_has_tool_calls([HumanMessage(content="hi")]) is False
|
||||
|
||||
|
||||
class _FakeRequest:
|
||||
def __init__(self, *, tools, messages, model) -> None:
|
||||
self.tools = tools
|
||||
self.messages = messages
|
||||
self.model = model
|
||||
|
||||
def override(self, *, tools):
|
||||
return _FakeRequest(tools=tools, messages=self.messages, model=self.model)
|
||||
|
||||
|
||||
class TestShouldInject:
    """All three conditions (no tools, pending tool calls, needy provider)
    must hold for the _noop tool to be injected."""

    def test_injects_when_all_conditions_met(self) -> None:
        mw = NoopInjectionMiddleware()
        history = [
            HumanMessage(content="hi"),
            AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]),
        ]
        request = _FakeRequest(tools=[], messages=history, model=_LiteLLMModel())
        assert mw._should_inject(request) is True

    def test_skips_when_tools_present(self) -> None:
        mw = NoopInjectionMiddleware()
        pending = AIMessage(
            content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]
        )
        request = _FakeRequest(
            tools=[object()], messages=[pending], model=_LiteLLMModel()
        )
        assert mw._should_inject(request) is False

    def test_skips_when_no_history_tool_calls(self) -> None:
        mw = NoopInjectionMiddleware()
        request = _FakeRequest(
            tools=[], messages=[HumanMessage(content="hi")], model=_LiteLLMModel()
        )
        assert mw._should_inject(request) is False

    def test_skips_for_openai(self) -> None:
        mw = NoopInjectionMiddleware()
        pending = AIMessage(
            content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]
        )
        request = _FakeRequest(tools=[], messages=[pending], model=_OpenAIModel())
        assert mw._should_inject(request) is False
|
||||
|
||||
|
||||
def test_noop_tool_name_is_underscore_noop() -> None:
    # Permission rules and prompts reference "_noop" by name; pin it.
    assert NOOP_TOOL_NAME == "_noop"
|
||||
195
surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py
Normal file
195
surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
"""Tests for the OtelSpanMiddleware adapter (Tier 3b)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.agents.new_chat.middleware.otel_span import (
|
||||
OtelSpanMiddleware,
|
||||
_annotate_model_response,
|
||||
_annotate_tool_result,
|
||||
_resolve_input_size,
|
||||
_resolve_model_attrs,
|
||||
_resolve_tool_name,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _disable_otel(monkeypatch: pytest.MonkeyPatch):
    # Force OTel off for every test in this module, then restore the
    # otel module's cached state on teardown via a second reload.
    monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
    monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true")
    from app.observability import otel as ot

    ot.reload_for_tests()
    yield
    ot.reload_for_tests()
|
||||
|
||||
|
||||
class TestResolveModelAttrs:
    """Extraction of (model name, provider) from a model request."""

    def test_extracts_model_name_and_provider(self) -> None:
        request = MagicMock()
        model_stub = MagicMock(spec=["model_name", "provider"])
        model_stub.model_name = "gpt-4o-mini"
        model_stub.provider = "openai"
        request.model = model_stub
        assert _resolve_model_attrs(request) == ("gpt-4o-mini", "openai")

    def test_handles_missing_model(self) -> None:
        request = MagicMock()
        request.model = None
        assert _resolve_model_attrs(request) == (None, None)

    def test_falls_back_through_attribute_chain(self) -> None:
        # spec restricts attributes so the resolver must walk its fallbacks.
        request = MagicMock()
        model_stub = MagicMock(spec=["model_id", "_llm_type"])
        model_stub.model_id = "claude-3-5-sonnet"
        model_stub._llm_type = "anthropic-chat"
        request.model = model_stub
        resolved_id, resolved_provider = _resolve_model_attrs(request)
        assert resolved_id == "claude-3-5-sonnet"
        assert resolved_provider == "anthropic-chat"
|
||||
|
||||
|
||||
class TestResolveToolName:
    """Tool-name resolution preference order: request.tool, then tool_call."""

    def test_prefers_request_tool_name(self) -> None:
        request = MagicMock()
        # MagicMock(name=...) names the mock itself; set .name explicitly.
        request.tool = MagicMock(name="ToolStub")
        request.tool.name = "scrape_webpage"
        assert _resolve_tool_name(request) == "scrape_webpage"

    def test_falls_back_to_tool_call_name(self) -> None:
        request = MagicMock()
        request.tool = None
        request.tool_call = {"name": "web_search", "args": {}}
        assert _resolve_tool_name(request) == "web_search"

    def test_unknown_when_nothing_resolves(self) -> None:
        request = MagicMock()
        request.tool = None
        request.tool_call = {}
        assert _resolve_tool_name(request) == "unknown"
|
||||
|
||||
|
||||
class TestResolveInputSize:
    """Approximate input size derived from the tool-call args."""

    def test_returns_repr_length_of_args(self) -> None:
        request = MagicMock()
        request.tool_call = {"args": {"query": "hello world"}}
        resolved = _resolve_input_size(request)
        assert isinstance(resolved, int) and resolved > 0

    def test_handles_no_tool_call(self) -> None:
        request = MagicMock()
        request.tool_call = None
        assert _resolve_input_size(request) is None
|
||||
|
||||
|
||||
class TestAnnotateModelResponse:
    """Span annotation with token usage from the model response."""

    def test_attaches_token_counts_when_present(self) -> None:
        span = MagicMock()
        response = AIMessage(
            content="hello",
            usage_metadata={
                "input_tokens": 100,
                "output_tokens": 50,
                "total_tokens": 150,
            },
        )
        _annotate_model_response(span, response)
        span.set_attribute.assert_any_call("tokens.prompt", 100)
        span.set_attribute.assert_any_call("tokens.completion", 50)
        span.set_attribute.assert_any_call("tokens.total", 150)

    def test_handles_response_with_no_metadata(self) -> None:
        span = MagicMock()
        # Should not raise even when usage_metadata is missing
        _annotate_model_response(span, AIMessage(content="hello"))
|
||||
|
||||
|
||||
class TestAnnotateToolResult:
    """Span annotation with tool output size, status, and error marker."""

    def test_records_size_and_status(self) -> None:
        span = MagicMock()
        message = ToolMessage(
            content="result text",
            tool_call_id="abc",
            status="success",
        )
        _annotate_tool_result(span, message)
        span.set_attribute.assert_any_call("tool.output.size", len("result text"))
        span.set_attribute.assert_any_call("tool.status", "success")

    def test_marks_errors(self) -> None:
        span = MagicMock()
        message = ToolMessage(
            content="oops",
            tool_call_id="abc",
            additional_kwargs={"error": {"code": "x"}},
        )
        _annotate_tool_result(span, message)
        span.set_attribute.assert_any_call("tool.error", True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestMiddlewareIntegration:
    """End-to-end wrap behavior of OtelSpanMiddleware, disabled and enabled."""

    async def test_awrap_model_call_passes_through_when_disabled(self) -> None:
        middleware = OtelSpanMiddleware()
        seen: dict[str, Any] = {}

        async def handler(req):
            seen["req"] = req
            return AIMessage(content="ok")

        fake_request = MagicMock()
        reply = await middleware.awrap_model_call(fake_request, handler)
        # The exact request object must reach the handler unchanged.
        assert seen["req"] is fake_request
        assert isinstance(reply, AIMessage)
        assert reply.content == "ok"

    async def test_awrap_tool_call_passes_through_when_disabled(self) -> None:
        middleware = OtelSpanMiddleware()

        async def handler(req):
            return ToolMessage(content="result", tool_call_id="abc")

        reply = await middleware.awrap_tool_call(MagicMock(), handler)
        assert isinstance(reply, ToolMessage)
        assert reply.content == "result"

    async def test_awrap_model_call_propagates_exceptions(self) -> None:
        middleware = OtelSpanMiddleware()

        async def handler(req):
            raise ValueError("boom")

        # Handler exceptions must surface to the caller, not be swallowed.
        with pytest.raises(ValueError):
            await middleware.awrap_model_call(MagicMock(), handler)

    async def test_with_otel_enabled_does_not_alter_result(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
        monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
        from app.observability import otel as ot

        ot.reload_for_tests()
        try:
            middleware = OtelSpanMiddleware()

            async def handler(req):
                return AIMessage(content="enabled")

            fake_request = MagicMock()
            fake_request.model = MagicMock()
            fake_request.model.model_name = "gpt-4o"
            fake_request.model.provider = "openai"
            reply = await middleware.awrap_model_call(fake_request, handler)
            assert isinstance(reply, AIMessage)
            assert reply.content == "enabled"
        finally:
            # Restore module state for subsequent tests regardless of outcome.
            ot.reload_for_tests()
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
"""Tests for PermissionMiddleware end-to-end behavior."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.agents.new_chat.errors import CorrectedError, RejectedError
|
||||
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeRuntime:
    """Minimal runtime stand-in exposing only the config the middleware reads."""

    # Class-level so every instance shares the same thread-scoped config.
    config: dict = {"configurable": {"thread_id": "test"}}
|
||||
|
||||
|
||||
def _msg(*tool_calls: dict) -> AIMessage:
    """Build an empty-content AIMessage carrying the given tool calls."""
    return AIMessage(content="", tool_calls=list(tool_calls))
|
||||
|
||||
|
||||
class TestAllow:
    """Allowed tool calls must pass through without any state change."""

    def test_passthrough_when_allow(self) -> None:
        middleware = PermissionMiddleware(
            rulesets=[Ruleset(rules=[Rule("send_email", "*", "allow")])]
        )
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        # None signals "no update" to the agent graph.
        assert middleware.after_model(state, _FakeRuntime()) is None
|
||||
|
||||
|
||||
class TestDeny:
    """Denied calls are stripped and replaced by an error ToolMessage."""

    def test_replaces_with_deny_tool_message(self) -> None:
        middleware = PermissionMiddleware(
            rulesets=[Ruleset(rules=[Rule("send_email", "*", "deny")])]
        )
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        update = middleware.after_model(state, _FakeRuntime())
        assert update is not None
        messages = update["messages"]
        tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
        assert len(tool_messages) == 1
        assert tool_messages[0].status == "error"
        assert "permission_denied" in str(tool_messages[0].additional_kwargs)
        # The denied call must also be dropped from the AIMessage itself.
        ai = next(m for m in messages if isinstance(m, AIMessage))
        assert ai.tool_calls == []

    def test_mixed_allow_deny(self) -> None:
        ruleset = Ruleset(
            rules=[
                Rule("send_email", "*", "deny"),
                Rule("read", "*", "allow"),
            ]
        )
        middleware = PermissionMiddleware(rulesets=[ruleset])
        state = {
            "messages": [
                _msg(
                    {"name": "send_email", "args": {}, "id": "1"},
                    {"name": "read", "args": {}, "id": "2"},
                )
            ]
        }
        update = middleware.after_model(state, _FakeRuntime())
        assert update is not None
        # Only the allowed call survives on the AIMessage.
        ai = next(m for m in update["messages"] if isinstance(m, AIMessage))
        assert [call["name"] for call in ai.tool_calls] == ["read"]
|
||||
|
||||
|
||||
class TestAsk:
    """Unmatched calls fall back to 'ask'; the user's decision drives the outcome."""

    def test_reject_without_feedback_raises(self) -> None:
        # Empty ruleset: nothing matches, so the default 'ask' path is taken.
        middleware = PermissionMiddleware(rulesets=[Ruleset(rules=[])])
        # Patch the interrupt helper instead of triggering a real interrupt.
        middleware._raise_interrupt = lambda **kw: {"decision_type": "reject"}  # type: ignore[assignment]
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        with pytest.raises(RejectedError):
            middleware.after_model(state, _FakeRuntime())

    def test_reject_with_feedback_raises_corrected(self) -> None:
        middleware = PermissionMiddleware(rulesets=[Ruleset(rules=[])])
        middleware._raise_interrupt = lambda **kw: {  # type: ignore[assignment]
            "decision_type": "reject",
            "feedback": "use a different subject line",
        }
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        with pytest.raises(CorrectedError) as excinfo:
            middleware.after_model(state, _FakeRuntime())
        # The user's feedback must ride along on the exception.
        assert excinfo.value.feedback == "use a different subject line"

    def test_once_proceeds_without_persisting(self) -> None:
        middleware = PermissionMiddleware(rulesets=[])
        middleware._raise_interrupt = lambda **kw: {"decision_type": "once"}  # type: ignore[assignment]
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        # All calls kept -> no state delta, and no rule remembered.
        assert middleware.after_model(state, _FakeRuntime()) is None
        assert middleware._runtime_ruleset.rules == []

    def test_always_persists_runtime_rule(self) -> None:
        middleware = PermissionMiddleware(rulesets=[])
        middleware._raise_interrupt = lambda **kw: {"decision_type": "always"}  # type: ignore[assignment]
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        # The call itself is kept (no state delta)...
        assert middleware.after_model(state, _FakeRuntime()) is None
        # ...and an allow rule is persisted in the runtime ruleset.
        allow_rules = [
            r for r in middleware._runtime_ruleset.rules if r.action == "allow"
        ]
        assert any(r.permission == "send_email" for r in allow_rules)
|
||||
111
surfsense_backend/tests/unit/agents/new_chat/test_permissions.py
Normal file
111
surfsense_backend/tests/unit/agents/new_chat/test_permissions.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Tests for the wildcard matcher and rule evaluator (opencode evaluate.ts parity)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.permissions import (
|
||||
Rule,
|
||||
Ruleset,
|
||||
aggregate_action,
|
||||
evaluate,
|
||||
evaluate_many,
|
||||
wildcard_match,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestWildcardMatch:
    """Parity cases for the opencode-style wildcard matcher."""

    @pytest.mark.parametrize(
        ("value", "pattern", "expected"),
        [
            ("edit", "edit", True),
            ("edit", "*", True),
            ("read", "edit", False),
            ("/documents/secrets/x", "/documents/secrets/**", True),
            # Single-segment glob: '*' does not cross '/'
            ("/documents/secrets/x", "/documents/*/x", True),
            ("/documents/foo/bar/x", "/documents/*/x", False),
            ("/documents/foo/x", "/documents/*/x", True),
            ("linear_create", "linear_*", True),
            ("notion_create", "linear_*", False),
            # ':' is not a separator, so '*' matches it
            ("mcp:notion:create_page", "mcp:*", True),
            ("mcp:notion:create_page", "mcp:**", True),
            # But '/' IS a separator
            ("foo/bar", "foo/*", True),
            ("foo/bar/baz", "foo/*", False),
        ],
    )
    def test_match(self, value: str, pattern: str, expected: bool) -> None:
        assert wildcard_match(value, pattern) is expected
|
||||
|
||||
|
||||
class TestEvaluate:
    """Rule evaluation semantics: defaults, precedence, layering, wildcards."""

    def test_default_action_is_ask(self) -> None:
        decision = evaluate("edit", "/foo/bar")
        assert decision.action == "ask"
        assert decision.permission == "edit"

    def test_last_match_wins(self) -> None:
        ruleset = Ruleset(
            rules=[
                Rule("edit", "*", "allow"),
                Rule("edit", "/secrets/**", "deny"),
            ]
        )
        # The later (and more specific) deny rule takes precedence...
        assert evaluate("edit", "/secrets/x", ruleset).action == "deny"
        # ...while the broad allow still covers everything else.
        assert evaluate("edit", "/public/x", ruleset).action == "allow"

    def test_layered_rulesets_later_overrides_earlier(self) -> None:
        defaults = Ruleset(rules=[Rule("edit", "*", "ask")], origin="defaults")
        space = Ruleset(rules=[Rule("edit", "*", "allow")], origin="space")
        thread = Ruleset(rules=[Rule("edit", "*", "deny")], origin="thread")
        # With all three stacked, the last layer wins.
        assert evaluate("edit", "x", defaults, space, thread).action == "deny"
        # Dropping the thread layer leaves the space rule in charge.
        assert evaluate("edit", "x", defaults, space).action == "allow"

    def test_permission_wildcard(self) -> None:
        ruleset = Ruleset(rules=[Rule("linear_*", "*", "allow")])
        assert evaluate("linear_create_issue", "x", ruleset).action == "allow"
        assert evaluate("notion_create", "x", ruleset).action == "ask"

    def test_pattern_wildcard(self) -> None:
        ruleset = Ruleset(rules=[Rule("edit", "/documents/secrets/**", "deny")])
        assert evaluate("edit", "/documents/secrets/foo", ruleset).action == "deny"
        assert evaluate("edit", "/documents/public/foo", ruleset).action == "ask"

    def test_evaluate_many(self) -> None:
        ruleset = Ruleset(
            rules=[
                Rule("edit", "*", "allow"),
                Rule("edit", "/secrets/*", "deny"),
            ]
        )
        outcomes = evaluate_many("edit", ["/public/x", "/secrets/y"], ruleset)
        assert [o.action for o in outcomes] == ["allow", "deny"]
|
||||
|
||||
|
||||
class TestAggregateAction:
    """Aggregation: deny beats ask beats allow; empty input defaults to ask."""

    def test_any_deny_means_deny(self) -> None:
        verdicts = [
            Rule("a", "*", "allow"),
            Rule("a", "*", "deny"),
            Rule("a", "*", "ask"),
        ]
        assert aggregate_action(verdicts) == "deny"

    def test_any_ask_means_ask_when_no_deny(self) -> None:
        verdicts = [Rule("a", "*", "allow"), Rule("a", "*", "ask")]
        assert aggregate_action(verdicts) == "ask"

    def test_all_allow_means_allow(self) -> None:
        verdicts = [Rule("a", "*", "allow"), Rule("a", "*", "allow")]
        assert aggregate_action(verdicts) == "allow"

    def test_empty_means_ask(self) -> None:
        # Conservative default: with no rules at all, ask the user.
        assert aggregate_action([]) == "ask"
|
||||
|
|
@ -0,0 +1,187 @@
|
|||
"""Unit tests for the SurfSense plugin entry-point loader (Tier 6)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PLUGIN_ENTRY_POINT_GROUP,
|
||||
PluginContext,
|
||||
load_allowed_plugin_names_from_env,
|
||||
load_plugin_middlewares,
|
||||
)
|
||||
from app.agents.new_chat.plugins.year_substituter import (
|
||||
_YearSubstituterMiddleware,
|
||||
make_middleware as year_substituter_factory,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _DummyMiddleware(AgentMiddleware):
    """Minimal concrete middleware returned by well-behaved plugin factories."""

    # Contributes no tools; only successful instantiation matters here.
    tools = ()
|
||||
|
||||
|
||||
def _ctx() -> PluginContext:
    """Build a minimal, valid PluginContext for loader tests."""
    return PluginContext.build(
        search_space_id=1,
        user_id="u",
        thread_visibility="PRIVATE",  # type: ignore[arg-type]
        llm=MagicMock(),
    )
|
||||
|
||||
|
||||
class _FakeEntryPoint:
    """Stand-in for ``importlib.metadata.EntryPoint``."""

    def __init__(self, name: str, factory) -> None:
        self.name = name
        self._factory = factory

    def load(self):
        # A real entry point imports its target here; we just return the factory.
        return self._factory
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Loader behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPluginLoaderBasics:
    """Allowlist gating of the plugin entry-point loader."""

    def test_returns_empty_when_allowlist_is_empty(self) -> None:
        assert load_plugin_middlewares(_ctx(), allowed_plugin_names=[]) == []

    def test_skips_non_allowlisted_plugin(self) -> None:
        invocations = []

        def factory(_):  # invoking this would be an obvious bug
            invocations.append(True)
            return _DummyMiddleware()

        entry = _FakeEntryPoint("dangerous_plugin", factory)
        with patch(
            "app.agents.new_chat.plugin_loader.entry_points",
            return_value=[entry],
        ):
            loaded = load_plugin_middlewares(
                _ctx(), allowed_plugin_names=["allowed_only"]
            )
        assert loaded == []
        # The factory of a non-allowlisted plugin must never even run.
        assert not invocations

    def test_loads_allowlisted_plugin(self) -> None:
        entry = _FakeEntryPoint("year_substituter", year_substituter_factory)
        with patch(
            "app.agents.new_chat.plugin_loader.entry_points",
            return_value=[entry],
        ):
            loaded = load_plugin_middlewares(
                _ctx(), allowed_plugin_names={"year_substituter"}
            )
        assert len(loaded) == 1
        assert isinstance(loaded[0], _YearSubstituterMiddleware)
|
||||
|
||||
|
||||
class TestPluginLoaderIsolation:
    """One misbehaving plugin must never break loading of the others."""

    def test_factory_exception_is_isolated(self) -> None:
        def crashing_factory(_):
            raise RuntimeError("boom")

        entry = _FakeEntryPoint("buggy", crashing_factory)
        with patch(
            "app.agents.new_chat.plugin_loader.entry_points",
            return_value=[entry],
        ):
            loaded = load_plugin_middlewares(_ctx(), allowed_plugin_names={"buggy"})
        # Construction continued without the crashing plugin.
        assert loaded == []

    def test_non_middleware_return_is_rejected(self) -> None:
        def bad_factory(_):
            return "not a middleware"

        entry = _FakeEntryPoint("liar", bad_factory)
        with patch(
            "app.agents.new_chat.plugin_loader.entry_points",
            return_value=[entry],
        ):
            loaded = load_plugin_middlewares(_ctx(), allowed_plugin_names={"liar"})
        assert loaded == []

    def test_load_phase_exception_is_isolated(self) -> None:
        class _BrokenEP:
            name = "broken"

            def load(self):
                raise ImportError("cannot import")

        with patch(
            "app.agents.new_chat.plugin_loader.entry_points",
            return_value=[_BrokenEP()],
        ):
            loaded = load_plugin_middlewares(_ctx(), allowed_plugin_names={"broken"})
        assert loaded == []

    def test_one_failure_does_not_block_others(self) -> None:
        """Two plugins; one crashes during factory; the other still loads."""

        def crashing_factory(_):
            raise RuntimeError("boom")

        entries = [
            _FakeEntryPoint("crashing", crashing_factory),
            _FakeEntryPoint("ok", year_substituter_factory),
        ]
        with patch(
            "app.agents.new_chat.plugin_loader.entry_points", return_value=entries
        ):
            loaded = load_plugin_middlewares(
                _ctx(), allowed_plugin_names={"crashing", "ok"}
            )
        assert len(loaded) == 1
        assert isinstance(loaded[0], _YearSubstituterMiddleware)
|
||||
|
||||
|
||||
class TestAllowlistEnv:
    """Parsing of SURFSENSE_ALLOWED_PLUGINS into an allowlist set."""

    def test_empty_env_returns_empty_set(self, monkeypatch) -> None:
        monkeypatch.delenv("SURFSENSE_ALLOWED_PLUGINS", raising=False)
        assert load_allowed_plugin_names_from_env() == set()

    def test_parses_comma_separated_value(self, monkeypatch) -> None:
        # Stray whitespace and a trailing comma must be tolerated.
        monkeypatch.setenv(
            "SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , "
        )
        assert load_allowed_plugin_names_from_env() == {
            "year_substituter",
            "noisy",
        }
|
||||
|
||||
|
||||
class TestPluginContext:
    """Shape and safety guarantees of PluginContext.build."""

    def test_build_includes_required_fields(self) -> None:
        fake_llm = MagicMock()
        ctx = PluginContext.build(
            search_space_id=42,
            user_id="user-1",
            thread_visibility="PRIVATE",  # type: ignore[arg-type]
            llm=fake_llm,
        )
        assert ctx["search_space_id"] == 42
        assert ctx["user_id"] == "user-1"
        assert ctx["llm"] is fake_llm

    def test_does_not_carry_secrets_or_db_session(self) -> None:
        ctx = _ctx()
        # If a future change tries to add these keys, this test will fail loudly.
        for forbidden in ("api_key", "secret", "db_session", "session"):
            assert forbidden not in ctx
|
||||
|
||||
|
||||
class TestEntryPointGroup:
    """The entry-point group name is part of the public plugin contract."""

    def test_group_name_matches_pyproject_convention(self) -> None:
        # Plugin authors register under `surfsense.plugins` in pyproject.toml.
        assert PLUGIN_ENTRY_POINT_GROUP == "surfsense.plugins"
|
||||
107
surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py
Normal file
107
surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""Tests for RetryAfterMiddleware Retry-After parsing and retry decision logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.middleware.retry_after import (
|
||||
RetryAfterMiddleware,
|
||||
_extract_retry_after_seconds,
|
||||
_is_non_retryable,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeResponse:
    """Header-only HTTP response stand-in."""

    def __init__(self, headers: dict[str, str]) -> None:
        self.headers = headers
|
||||
|
||||
|
||||
class _FakeRateLimit(Exception):
    """Rate-limit-style exception that may carry a response with headers."""

    def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None:
        super().__init__(msg)
        # Only attach .response when headers were supplied, mirroring SDK
        # exceptions where the attribute can be absent entirely.
        if headers is not None:
            self.response = _FakeResponse(headers)
|
||||
|
||||
|
||||
class TestExtractRetryAfter:
    """Header- and message-based extraction of the retry-after hint."""

    def test_seconds_header(self) -> None:
        rate_limited = _FakeRateLimit("rate", {"Retry-After": "30"})
        assert _extract_retry_after_seconds(rate_limited) == 30.0

    def test_milliseconds_header_overrides_seconds(self) -> None:
        rate_limited = _FakeRateLimit("rate", {"retry-after-ms": "1500"})
        # Millisecond header is converted to seconds.
        assert _extract_retry_after_seconds(rate_limited) == 1.5

    def test_case_insensitive(self) -> None:
        rate_limited = _FakeRateLimit("rate", {"RETRY-AFTER": "12"})
        assert _extract_retry_after_seconds(rate_limited) == 12.0

    def test_falls_back_to_message_regex(self) -> None:
        hinted = Exception("Please retry after 7 seconds")
        assert _extract_retry_after_seconds(hinted) == 7.0

    def test_returns_none_when_no_hint(self) -> None:
        assert _extract_retry_after_seconds(Exception("oops")) is None

    def test_handles_missing_headers_attr(self) -> None:
        # Exceptions without a .response/.headers must not crash extraction.
        assert _extract_retry_after_seconds(ValueError("no headers")) is None
|
||||
|
||||
|
||||
class TestIsNonRetryable:
    """Classification of exceptions that must never be retried."""

    @pytest.mark.parametrize(
        "name",
        ["ContextWindowExceededError", "AuthenticationError", "InvalidRequestError"],
    )
    def test_non_retryable_classes(self, name: str) -> None:
        # Build the exception class dynamically by its well-known name.
        exc_cls = type(name, (Exception,), {})
        assert _is_non_retryable(exc_cls("x")) is True

    def test_generic_exception_is_retryable(self) -> None:
        assert _is_non_retryable(RuntimeError("transient")) is False
|
||||
|
||||
|
||||
class TestDelayCalculation:
    """Delay = max(exponential backoff, server hint), capped at max_delay."""

    def test_takes_max_of_backoff_and_header(self) -> None:
        middleware = RetryAfterMiddleware(
            max_retries=3, initial_delay=1.0, jitter=False
        )
        rate_limited = _FakeRateLimit("rl", {"retry-after": "10"})
        # Server hint (10s) dominates the small first-attempt backoff.
        assert middleware._delay_for_attempt(0, rate_limited) == pytest.approx(10.0)

    def test_uses_backoff_when_no_header(self) -> None:
        middleware = RetryAfterMiddleware(
            max_retries=3, initial_delay=2.0, backoff_factor=2.0, jitter=False
        )
        # initial_delay * backoff_factor**attempt = 2 * 2^2 = 8
        delay = middleware._delay_for_attempt(2, RuntimeError("transient"))
        assert delay == pytest.approx(8.0)

    def test_caps_at_max_delay(self) -> None:
        middleware = RetryAfterMiddleware(
            max_retries=3,
            initial_delay=10.0,
            backoff_factor=10.0,
            max_delay=15.0,
            jitter=False,
        )
        assert middleware._delay_for_attempt(5, RuntimeError("x")) <= 15.0
|
||||
|
||||
|
||||
class TestShouldRetry:
    """Retry decision: defaults plus a caller-supplied retry_on predicate."""

    def test_default_retries_generic(self) -> None:
        assert RetryAfterMiddleware()._should_retry(RuntimeError("transient")) is True

    def test_default_skips_non_retryable(self) -> None:
        exc_cls = type("ContextWindowExceededError", (Exception,), {})
        assert RetryAfterMiddleware()._should_retry(exc_cls("too big")) is False

    def test_custom_retry_on(self) -> None:
        middleware = RetryAfterMiddleware(
            retry_on=lambda exc: isinstance(exc, ValueError)
        )
        assert middleware._should_retry(ValueError()) is True
        assert middleware._should_retry(KeyError()) is False
|
||||
|
|
@ -0,0 +1,242 @@
|
|||
"""Tests for the skills backends used by SurfSense's SkillsMiddleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.middleware.skills_backends import (
|
||||
SKILLS_BUILTIN_PREFIX,
|
||||
SKILLS_SPACE_PREFIX,
|
||||
BuiltinSkillsBackend,
|
||||
SearchSpaceSkillsBackend,
|
||||
build_skills_backend_factory,
|
||||
default_skills_sources,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def skills_root(tmp_path: Path) -> Path:
    """Create a tiny on-disk skill tree: two valid skills plus a stray file."""
    root = tmp_path / "skills"

    alpha_dir = root / "alpha"
    alpha_dir.mkdir(parents=True)
    (alpha_dir / "SKILL.md").write_text(
        "---\nname: alpha\ndescription: alpha skill\n---\n# Alpha\n"
    )

    beta_dir = root / "beta"
    beta_dir.mkdir(parents=True)
    (beta_dir / "SKILL.md").write_text(
        "---\nname: beta\ndescription: beta skill\n---\n# Beta\n"
    )

    # Top-level file that is NOT a skill directory.
    (root / "_orphan_file.md").write_text("not a skill, just a stray file")
    return root
|
||||
|
||||
|
||||
class TestBuiltinSkillsBackendListing:
    """Directory-listing behavior of BuiltinSkillsBackend.ls_info."""

    def test_lists_skill_directories_at_root(self, skills_root: Path) -> None:
        entries = BuiltinSkillsBackend(skills_root).ls_info("/")
        paths = {entry["path"] for entry in entries}
        assert "/alpha" in paths
        assert "/beta" in paths
        assert "/_orphan_file.md" in paths
        # Both skill folders must be reported as directories.
        for entry in entries:
            if entry["path"] in {"/alpha", "/beta"}:
                assert entry["is_dir"] is True

    def test_lists_skill_md_under_skill_directory(self, skills_root: Path) -> None:
        entries = BuiltinSkillsBackend(skills_root).ls_info("/alpha")
        assert {entry["path"] for entry in entries} == {"/alpha/SKILL.md"}
        assert entries[0]["is_dir"] is False
        assert entries[0]["size"] > 0

    def test_returns_empty_for_missing_path(self, skills_root: Path) -> None:
        assert BuiltinSkillsBackend(skills_root).ls_info("/nonexistent") == []

    def test_returns_empty_when_root_missing(self, tmp_path: Path) -> None:
        backend = BuiltinSkillsBackend(tmp_path / "definitely-missing")
        assert backend.ls_info("/") == []
        assert backend.download_files(["/x/SKILL.md"])[0].error == "file_not_found"

    def test_refuses_path_traversal(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        # Neither listing nor downloading may escape the skills root.
        assert backend.ls_info("/../../../etc") == []
        responses = backend.download_files(["/../../../etc/passwd"])
        assert responses[0].error == "invalid_path"
|
||||
|
||||
|
||||
class TestBuiltinSkillsBackendDownload:
    """File-download behavior of BuiltinSkillsBackend.download_files."""

    def test_downloads_skill_md_content(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        responses = backend.download_files(["/alpha/SKILL.md", "/beta/SKILL.md"])
        assert len(responses) == 2
        first, second = responses
        assert first.path == "/alpha/SKILL.md"
        assert first.content is not None
        assert b"name: alpha" in first.content
        assert second.error is None

    def test_marks_directory_as_is_directory_error(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        assert backend.download_files(["/alpha"])[0].error == "is_directory"

    def test_marks_missing_file_as_file_not_found(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        response = backend.download_files(["/alpha/missing.md"])[0]
        assert response.error == "file_not_found"
        assert response.content is None

    def test_response_path_matches_input_for_correlation(
        self, skills_root: Path
    ) -> None:
        backend = BuiltinSkillsBackend(skills_root)
        requested = ["/alpha/SKILL.md", "/missing.md", "/beta/SKILL.md"]
        # Responses must come back in request order for caller-side correlation.
        assert [r.path for r in backend.download_files(requested)] == requested
|
||||
|
||||
|
||||
class TestBuiltinSkillsBackendIntegration:
    """Mirror the exact call sequence the SkillsMiddleware performs."""

    def test_skills_middleware_call_pattern(self, skills_root: Path) -> None:
        backend = BuiltinSkillsBackend(skills_root)

        # Step 1: enumerate skill directories at the root.
        entries = asyncio.run(backend.als_info("/"))
        skill_dirs = [entry["path"] for entry in entries if entry.get("is_dir")]
        assert sorted(skill_dirs) == ["/alpha", "/beta"]

        # Step 2: bulk-download each directory's SKILL.md.
        responses = asyncio.run(
            backend.adownload_files([f"{directory}/SKILL.md" for directory in skill_dirs])
        )
        assert all(r.error is None for r in responses)
        assert all(r.content is not None for r in responses)
|
||||
|
||||
|
||||
class TestBundledSkills:
    """The repo-bundled starter skills must exist and be well-formed."""

    def test_default_root_resolves_to_repo_skills_dir(self) -> None:
        backend = BuiltinSkillsBackend()
        assert backend.root.name == "builtin"
        assert backend.root.parent.name == "skills"

    def test_bundled_starter_skills_are_present(self) -> None:
        entries = BuiltinSkillsBackend().ls_info("/")
        present = {
            entry["path"].lstrip("/") for entry in entries if entry.get("is_dir")
        }
        # Five starter skills required by the Tier 4 plan.
        for required in (
            "kb-research",
            "report-writing",
            "meeting-prep",
            "slack-summary",
            "email-drafting",
        ):
            assert required in present, f"missing starter skill: {required}"

    def test_each_starter_skill_has_valid_skill_md(self) -> None:
        backend = BuiltinSkillsBackend()
        skill_dirs = [
            entry["path"] for entry in backend.ls_info("/") if entry.get("is_dir")
        ]
        for skill_dir in skill_dirs:
            response = backend.download_files([f"{skill_dir}/SKILL.md"])[0]
            assert response.error is None, f"missing SKILL.md in {skill_dir}"
            # Normalize line endings before checking the YAML frontmatter.
            text = response.content.decode("utf-8").replace("\r\n", "\n")
            assert text.startswith("---\n"), f"missing frontmatter in {skill_dir}"
            assert "\nname:" in text
            assert "\ndescription:" in text
|
||||
|
||||
|
||||
class _FakeKBBackend:
    """Stand-in for :class:`KBPostgresBackend` exposing the two methods used."""

    def __init__(self, listing: list[dict], file_contents: dict[str, bytes]) -> None:
        self._listing = listing
        self._file_contents = file_contents
        # Call arguments recorded so tests can assert on path remapping.
        self.last_ls_path: str | None = None
        self.last_download_paths: list[str] | None = None

    async def als_info(self, path: str):
        self.last_ls_path = path
        return self._listing

    async def adownload_files(self, paths):
        from deepagents.backends.protocol import FileDownloadResponse

        self.last_download_paths = list(paths)
        responses: list[FileDownloadResponse] = []
        for requested in paths:
            payload = self._file_contents.get(requested)
            if payload is None:
                responses.append(
                    FileDownloadResponse(path=requested, error="file_not_found")
                )
            else:
                responses.append(
                    FileDownloadResponse(path=requested, content=payload)
                )
        return responses
|
||||
|
||||
|
||||
class TestSearchSpaceSkillsBackend:
    """Path remapping between the skills namespace and the KB's _skills folder."""

    def test_remaps_paths_when_listing(self) -> None:
        kb = _FakeKBBackend(
            listing=[
                {"path": "/documents/_skills/policy", "is_dir": True},
                {"path": "/documents/_skills/policy/SKILL.md", "is_dir": False},
                {"path": "/documents/other-folder/x.md", "is_dir": False},
            ],
            file_contents={},
        )
        entries = asyncio.run(SearchSpaceSkillsBackend(kb).als_info("/"))
        assert kb.last_ls_path == "/documents/_skills"
        mapped = [entry["path"] for entry in entries]
        assert "/policy" in mapped
        assert "/policy/SKILL.md" in mapped
        # Unrelated KB documents must NOT leak into the skills namespace.
        assert all(not p.startswith("/documents") for p in mapped)

    def test_remaps_paths_when_downloading(self) -> None:
        kb = _FakeKBBackend(
            listing=[],
            file_contents={
                "/documents/_skills/policy/SKILL.md": b"---\nname: policy\n---\n",
            },
        )
        responses = asyncio.run(
            SearchSpaceSkillsBackend(kb).adownload_files(["/policy/SKILL.md"])
        )
        # The KB sees the remapped path, the caller sees the original one.
        assert kb.last_download_paths == ["/documents/_skills/policy/SKILL.md"]
        assert responses[0].path == "/policy/SKILL.md"
        assert responses[0].error is None
        assert responses[0].content is not None

    def test_sync_methods_raise_not_implemented(self) -> None:
        backend = SearchSpaceSkillsBackend(_FakeKBBackend([], {}))
        with pytest.raises(NotImplementedError):
            backend.ls_info("/")
        with pytest.raises(NotImplementedError):
            backend.download_files(["/x"])

    def test_custom_kb_root_is_honored(self) -> None:
        kb = _FakeKBBackend(
            listing=[
                {"path": "/skills_admin/x", "is_dir": True},
            ],
            file_contents={},
        )
        backend = SearchSpaceSkillsBackend(kb, kb_root="/skills_admin")
        entries = asyncio.run(backend.als_info("/"))
        assert kb.last_ls_path == "/skills_admin"
        assert entries[0]["path"] == "/x"
|
||||
|
||||
|
||||
class TestBackendFactory:
    """Composition of the default skills backend routes."""

    def test_builtin_only_factory_returns_composite(self) -> None:
        backend = build_skills_backend_factory()(runtime=None)  # type: ignore[arg-type]
        from deepagents.backends.composite import CompositeBackend

        assert isinstance(backend, CompositeBackend)
        # Without a KB backend only the builtin route is mounted.
        assert SKILLS_BUILTIN_PREFIX in backend.routes
        assert SKILLS_SPACE_PREFIX not in backend.routes

    def test_default_skills_sources_lists_builtin_then_space(self) -> None:
        assert default_skills_sources() == [SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX]
|
||||
|
|
@ -0,0 +1,338 @@
|
|||
"""Tests for the specialized subagents (explore / report_writer / connector_negotiator)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
||||
from app.agents.new_chat.subagents import (
|
||||
build_connector_negotiator_subagent,
|
||||
build_explore_subagent,
|
||||
build_report_writer_subagent,
|
||||
build_specialized_subagents,
|
||||
)
|
||||
from app.agents.new_chat.subagents.config import (
|
||||
EXPLORE_READ_TOOLS,
|
||||
REPORT_WRITER_TOOLS,
|
||||
WRITE_TOOL_DENY_PATTERNS,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake tools used to verify filtering & permission behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Read-only retrieval / inspection tools. Docstrings double as the runtime
# tool descriptions exposed to the model, so they are kept short and literal.
# ---------------------------------------------------------------------------


@tool
def search_surfsense_docs(query: str) -> str:
    """Search the user's KB."""
    return ""


@tool
def web_search(query: str) -> str:
    """Search the public web."""
    return ""


@tool
def scrape_webpage(url: str) -> str:
    """Scrape a single webpage."""
    return ""


@tool
def read_file(path: str) -> str:
    """Read a file."""
    return ""


@tool
def ls_tree(path: str) -> str:
    """List a tree."""
    return ""


@tool
def grep(pattern: str) -> str:
    """Grep."""
    return ""


# ---------------------------------------------------------------------------
# Mutating tools — these should be filtered out or denied by the subagents
# under test.
# ---------------------------------------------------------------------------


@tool
def update_memory(content: str) -> str:
    """Update the user's memory."""
    return ""


@tool
def edit_file(path: str, old: str, new: str) -> str:
    """Edit a file."""
    return ""


@tool
def linear_create_issue(title: str) -> str:
    """Create a Linear issue."""
    return ""


@tool
def slack_send_message(channel: str, text: str) -> str:
    """Send a Slack message."""
    return ""


# ---------------------------------------------------------------------------
# Connector / report tools — safe reads that must survive deny-pattern checks.
# ---------------------------------------------------------------------------


@tool
def get_connected_accounts() -> str:
    """List connected accounts."""
    return ""


@tool
def generate_report(topic: str) -> str:
    """Generate a report artifact."""
    return ""


# Full parent tool palette handed to each subagent builder in the tests below.
ALL_TOOLS = [
    search_surfsense_docs,
    web_search,
    scrape_webpage,
    read_file,
    ls_tree,
    grep,
    update_memory,
    edit_file,
    linear_create_issue,
    slack_send_message,
    get_connected_accounts,
    generate_report,
]
|
||||
|
||||
|
||||
class TestExploreSubagent:
    """Contract tests for the read-only ``explore`` subagent."""

    def test_only_read_tools_are_exposed(self) -> None:
        spec = build_explore_subagent(tools=ALL_TOOLS)
        exposed = {t.name for t in spec["tools"]}  # type: ignore[index]
        parent_names = {t.name for t in ALL_TOOLS}
        assert exposed == EXPLORE_READ_TOOLS & parent_names
        # Mutating tools must be trimmed from the exposed set entirely.
        for forbidden in ("update_memory", "linear_create_issue", "edit_file"):
            assert forbidden not in exposed

    def test_includes_permission_middleware_with_deny_rules(self) -> None:
        spec = build_explore_subagent(tools=ALL_TOOLS)
        permission_mws = [
            mw
            for mw in spec["middleware"]  # type: ignore[index]
            if isinstance(mw, PermissionMiddleware)
        ]
        assert len(permission_mws) == 1
        ruleset = permission_mws[0]._static_rulesets[0]
        assert ruleset.origin == "subagent_explore"
        denied = {rule.permission for rule in ruleset.rules if rule.action == "deny"}
        for pattern in ("update_memory", "edit_file", "*create*", "*send*"):
            assert pattern in denied

    def test_skills_inherits_default_sources(self) -> None:
        spec = build_explore_subagent(tools=ALL_TOOLS)
        assert spec["skills"] == ["/skills/builtin/", "/skills/space/"]  # type: ignore[index]

    def test_name_and_description_match_contract(self) -> None:
        spec = build_explore_subagent(tools=ALL_TOOLS)
        assert spec["name"] == "explore"
        assert "read-only" in spec["description"].lower()

    def test_includes_dedup_and_patch_middleware(self) -> None:
        from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware

        from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware

        spec = build_explore_subagent(tools=ALL_TOOLS)
        middleware_types = {type(mw) for mw in spec["middleware"]}  # type: ignore[index]
        assert PatchToolCallsMiddleware in middleware_types
        assert DedupHITLToolCallsMiddleware in middleware_types
|
||||
|
||||
|
||||
class TestReportWriterSubagent:
    """Contract tests for the ``report_writer`` subagent."""

    def test_exposes_only_report_writing_tools(self) -> None:
        spec = build_report_writer_subagent(tools=ALL_TOOLS)
        exposed = {t.name for t in spec["tools"]}  # type: ignore[index]
        assert exposed == REPORT_WRITER_TOOLS & {t.name for t in ALL_TOOLS}
        assert "generate_report" in exposed
        assert "search_surfsense_docs" in exposed

    def test_deny_rules_block_writes_but_allow_generate_report(self) -> None:
        spec = build_report_writer_subagent(tools=ALL_TOOLS)
        permission_mws = [
            mw
            for mw in spec["middleware"]  # type: ignore[index]
            if isinstance(mw, PermissionMiddleware)
        ]
        ruleset = permission_mws[0]._static_rulesets[0]
        denied = {rule.permission for rule in ruleset.rules if rule.action == "deny"}
        assert "update_memory" in denied
        # generate_report MUST not be denied — it's the whole point of the subagent.
        assert "generate_report" not in denied
        # No deny pattern should match `generate_report` either.
        assert not any(_wildcard_matches(p, "generate_report") for p in denied)
|
||||
|
||||
|
||||
class TestConnectorNegotiatorSubagent:
    """Contract tests for the ``connector_negotiator`` subagent."""

    def test_inherits_all_parent_tools(self) -> None:
        spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
        # Every parent tool is inherited; the deny ruleset enforces behavior
        # at execution time instead of trimming the tool list.
        assert {t.name for t in spec["tools"]} == {t.name for t in ALL_TOOLS}  # type: ignore[index]

    def test_get_connected_accounts_is_present(self) -> None:
        spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
        assert "get_connected_accounts" in {t.name for t in spec["tools"]}  # type: ignore[index]

    def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None:
        spec = build_connector_negotiator_subagent(tools=ALL_TOOLS)
        permission_mws = [
            mw
            for mw in spec["middleware"]  # type: ignore[index]
            if isinstance(mw, PermissionMiddleware)
        ]
        denied = {
            rule.permission
            for rule in permission_mws[0]._static_rulesets[0].rules
            if rule.action == "deny"
        }
        # `linear_create_issue` matches the `*_create` deny pattern; Slack's
        # send tool matches the send-style pattern.
        for mutating in ("linear_create_issue", "slack_send_message"):
            assert any(_wildcard_matches(p, mutating) for p in denied)
|
||||
|
||||
|
||||
class TestBuildSpecializedSubagents:
    """Tests for the aggregate specialized-subagent builder."""

    def test_returns_three_specs(self) -> None:
        specs = build_specialized_subagents(tools=ALL_TOOLS)
        assert [s["name"] for s in specs] == [  # type: ignore[index]
            "explore",
            "report_writer",
            "connector_negotiator",
        ]

    def test_all_specs_have_unique_names(self) -> None:
        specs = build_specialized_subagents(tools=ALL_TOOLS)
        names = [s["name"] for s in specs]  # type: ignore[index]
        assert len(names) == len(set(names))

    def test_extra_middleware_is_prepended_to_each_spec(self) -> None:
        """Sentinel middleware passed via ``extra_middleware`` must appear
        in each subagent's ``middleware`` list, before the local rules.

        This guards against the regression where specialized subagents
        promised filesystem tools (``read_file``, ``ls``, ``grep``) in
        their system prompts but had no filesystem middleware mounted.
        """

        class _Sentinel:
            pass

        marker = _Sentinel()
        specs = build_specialized_subagents(tools=ALL_TOOLS, extra_middleware=[marker])
        for spec in specs:
            middleware = spec["middleware"]  # type: ignore[index]
            assert marker in middleware
            # The sentinel must appear *before* the permission middleware
            # (subagent-local rules), preserving the documented composition
            # order: extra → custom → patch → dedup.
            permission_positions = [
                i
                for i, mw in enumerate(middleware)
                if isinstance(mw, PermissionMiddleware)
            ]
            assert permission_positions
            assert middleware.index(marker) < permission_positions[0]
|
||||
|
||||
|
||||
class TestFilterToolsWarningSuppression:
    """Names provided by middleware (read_file, ls, grep, …) must not
    trigger the spurious "missing" warning in :func:`_filter_tools`."""

    def test_middleware_provided_names_are_silent(self, caplog) -> None:
        import logging

        from app.agents.new_chat.subagents.config import _filter_tools

        # Allowed set asks for two registry tools (one present, one
        # not) plus a bunch of middleware-provided names.
        allowed = {
            "search_surfsense_docs",
            "scrape_webpage",  # legitimately missing → should warn
            "read_file",  # mw-provided → suppressed
            "ls",
            "grep",
            "glob",
            "write_todos",
        }
        with caplog.at_level(
            logging.INFO, logger="app.agents.new_chat.subagents.config"
        ):
            _filter_tools([search_surfsense_docs], allowed_names=allowed)

        emitted = [rec.message for rec in caplog.records if rec.levelno >= logging.INFO]
        # Exactly one warning, and it should mention scrape_webpage but not
        # any middleware-provided name. Inspect the rendered "missing"
        # list (between the brackets) so we don't false-match substrings
        # like ``ls`` inside ``available``.
        assert len(emitted) == 1, emitted
        message = emitted[0]
        assert "scrape_webpage" in message
        missing_section = message.split("missing: ", 1)[1]
        for suppressed in ("read_file", "ls", "grep", "glob", "write_todos"):
            assert f"'{suppressed}'" not in missing_section, message
|
||||
|
||||
|
||||
class TestDenyPatternsCoverage:
    """``WRITE_TOOL_DENY_PATTERNS`` must catch writes and spare reads."""

    # Representative mutating tool names a deny pattern must cover.
    _CANONICAL_WRITES = (
        "update_memory",
        "edit_file",
        "write_file",
        "move_file",
        "mkdir",
        "linear_create_issue",
        "linear_update_issue",
        "linear_delete_issue",
        "slack_send_message",
        "create_index",
        "update_account",
        "delete_record",
        "send_email",
    )
    # Representative read-only tool names no deny pattern may touch.
    _CANONICAL_READS = (
        "search_surfsense_docs",
        "read_file",
        "ls_tree",
        "grep",
        "web_search",
        "scrape_webpage",
        "get_connected_accounts",
        "generate_report",
    )

    def test_deny_patterns_cover_canonical_write_tools(self) -> None:
        for tool_name in self._CANONICAL_WRITES:
            covered = any(
                _wildcard_matches(pattern, tool_name)
                for pattern in WRITE_TOOL_DENY_PATTERNS
            )
            assert covered, f"no deny pattern matches {tool_name!r}"

    def test_deny_patterns_do_not_match_safe_read_tools(self) -> None:
        for tool_name in self._CANONICAL_READS:
            covered = any(
                _wildcard_matches(pattern, tool_name)
                for pattern in WRITE_TOOL_DENY_PATTERNS
            )
            assert not covered, f"deny pattern incorrectly matches read tool {tool_name!r}"
|
||||
|
||||
|
||||
def _wildcard_matches(pattern: str, value: str) -> bool:
    """Helper using the same matcher the rule evaluator does."""
    # Imported lazily so the test module stays importable even if the
    # permissions module grows heavier dependencies.
    from app.agents.new_chat.permissions import wildcard_match

    # NOTE: argument order flips here — wildcard_match takes (value, pattern).
    return wildcard_match(value, pattern)
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
"""Tests for ToolCallNameRepairMiddleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agents.new_chat.middleware.tool_call_repair import (
|
||||
ToolCallNameRepairMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_state(message: AIMessage) -> dict:
|
||||
return {"messages": [message]}
|
||||
|
||||
|
||||
class _FakeRuntime:
|
||||
def __init__(self, context: object | None = None) -> None:
|
||||
self.context = context
|
||||
|
||||
|
||||
class TestRepair:
    """Behavioral tests for ``ToolCallNameRepairMiddleware.after_model``."""

    @staticmethod
    def _single_call_message(name: str, args: dict | None = None) -> AIMessage:
        # Builds an AIMessage carrying exactly one tool call with id "1".
        return AIMessage(
            content="",
            tool_calls=[{"name": name, "args": args or {}, "id": "1"}],
        )

    def test_passthrough_when_name_matches(self) -> None:
        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"echo"}, fuzzy_match_threshold=None
        )
        result = mw.after_model(
            _make_state(self._single_call_message("echo")), _FakeRuntime()
        )
        assert result is None  # no change

    def test_lowercase_repair(self) -> None:
        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"echo"}, fuzzy_match_threshold=None
        )
        result = mw.after_model(
            _make_state(self._single_call_message("Echo", {"x": 1})), _FakeRuntime()
        )
        assert result is not None
        assert result["messages"][0].tool_calls[0]["name"] == "echo"

    def test_invalid_fallback_when_no_match(self) -> None:
        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"echo", INVALID_TOOL_NAME},
            fuzzy_match_threshold=None,
        )
        state = _make_state(
            self._single_call_message("totally_different_name", {"k": "v"})
        )
        result = mw.after_model(state, _FakeRuntime())
        assert result is not None
        repaired_call = result["messages"][0].tool_calls[0]
        assert repaired_call["name"] == INVALID_TOOL_NAME
        # The original (bad) name is preserved for diagnostics.
        assert repaired_call["args"]["tool"] == "totally_different_name"
        assert "totally_different_name" in repaired_call["args"]["error"]

    def test_no_invalid_means_skip_when_unknown(self) -> None:
        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"echo"}, fuzzy_match_threshold=None
        )
        result = mw.after_model(
            _make_state(self._single_call_message("unknown")), _FakeRuntime()
        )
        # No repair available; original returned unchanged (no update)
        assert result is None

    def test_fuzzy_match_works_when_enabled(self) -> None:
        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"search_documents"},
            fuzzy_match_threshold=0.7,
        )
        result = mw.after_model(
            _make_state(self._single_call_message("search_docments")), _FakeRuntime()
        )
        assert result is not None
        assert result["messages"][0].tool_calls[0]["name"] == "search_documents"

    def test_skips_when_no_messages(self) -> None:
        mw = ToolCallNameRepairMiddleware(registered_tool_names={"echo"})
        assert mw.after_model({"messages": []}, _FakeRuntime()) is None

    def test_runtime_context_extends_registered(self) -> None:
        from types import SimpleNamespace

        mw = ToolCallNameRepairMiddleware(
            registered_tool_names={"echo"}, fuzzy_match_threshold=None
        )
        runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"]))
        result = mw.after_model(
            _make_state(self._single_call_message("DynamicTool")), runtime
        )
        assert result is not None
        assert result["messages"][0].tool_calls[0]["name"] == "dynamictool"
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
import pytest
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
wrap_dedup_key_by_arg_name,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
|
@ -14,9 +16,34 @@ def _make_state(tool_calls: list[dict]) -> dict:
|
|||
return {"messages": [msg]}
|
||||
|
||||
|
||||
def _hitl_tool(name: str, *, dedup_arg: str) -> StructuredTool:
|
||||
"""Build a tool with declarative ``dedup_key`` metadata.
|
||||
|
||||
Mirrors the ``ToolDefinition.dedup_key`` -> ``tool.metadata["dedup_key"]``
|
||||
propagation done by :func:`build_tools` after the cleanup tier.
|
||||
"""
|
||||
|
||||
def _fn(**kwargs):
|
||||
return "ok"
|
||||
|
||||
return StructuredTool.from_function(
|
||||
func=_fn,
|
||||
name=name,
|
||||
description="x",
|
||||
metadata={"dedup_key": wrap_dedup_key_by_arg_name(dedup_arg)},
|
||||
)
|
||||
|
||||
|
||||
def test_duplicate_hitl_calls_reduced_to_first():
|
||||
"""When the LLM emits the same HITL tool call twice, only the first is kept."""
|
||||
mw = DedupHITLToolCallsMiddleware()
|
||||
"""When the LLM emits the same HITL tool call twice, only the first is kept.
|
||||
|
||||
After the cleanup tier removed ``_NATIVE_HITL_TOOL_DEDUP_KEYS``, the
|
||||
resolver is sourced from ``ToolDefinition.dedup_key`` propagated onto
|
||||
``tool.metadata`` — which the registry does at agent build time. The
|
||||
test mirrors that wiring with an in-memory tool.
|
||||
"""
|
||||
tool = _hitl_tool("delete_calendar_event", dedup_arg="event_title_or_id")
|
||||
mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
|
||||
|
||||
state = _make_state(
|
||||
[
|
||||
|
|
|
|||
1
surfsense_backend/tests/unit/observability/__init__.py
Normal file
1
surfsense_backend/tests/unit/observability/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
84
surfsense_backend/tests/unit/observability/test_otel.py
Normal file
84
surfsense_backend/tests/unit/observability/test_otel.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
"""Tests for the SurfSense OpenTelemetry shim (Tier 3b)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.observability import otel
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_otel_state(monkeypatch: pytest.MonkeyPatch):
    """Force a clean OTel disabled state per test, then restore after."""
    for env_var in ("OTEL_EXPORTER_OTLP_ENDPOINT", "SURFSENSE_DISABLE_OTEL"):
        monkeypatch.delenv(env_var, raising=False)
    # Explicitly disable so every test starts from a known-off baseline.
    monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true")
    otel.reload_for_tests()
    yield
    # Re-evaluate module state for whatever environment follows this test.
    otel.reload_for_tests()
|
||||
|
||||
|
||||
def test_disabled_by_default_when_no_endpoint() -> None:
    # The autouse fixture leaves no endpoint configured → the shim stays off.
    assert otel.is_enabled() is False


def test_enabled_when_endpoint_configured(monkeypatch: pytest.MonkeyPatch) -> None:
    # An OTLP endpoint with the kill switch cleared turns the shim on.
    monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
    assert otel.reload_for_tests() is True


def test_kill_switch_overrides_endpoint(monkeypatch: pytest.MonkeyPatch) -> None:
    # Even with an endpoint present, SURFSENSE_DISABLE_OTEL wins.
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
    monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true")
    assert otel.reload_for_tests() is False
|
||||
|
||||
|
||||
class TestNoopSpansWhenDisabled:
    """While disabled, every span helper must hand back a tolerant no-op."""

    def test_generic_span_yields_noop(self) -> None:
        with otel.span("any.thing", attributes={"x": 1}) as sp:
            sp.set_attribute("y", 2)
            sp.set_attributes({"a": "b"})
            sp.add_event("evt")
            sp.record_exception(RuntimeError("ignored"))
            sp.set_status("ignored")
        # Reaching here without raising means the no-op is well-formed

    def test_exception_propagates_through_span(self) -> None:
        with pytest.raises(ValueError), otel.span("err"):
            raise ValueError("boom")

    def test_each_helper_is_a_no_op_when_disabled(self) -> None:
        span_managers = (
            otel.tool_call_span("write_file", input_size=42),
            otel.model_call_span(model_id="openai:gpt-4o", provider="openai"),
            otel.kb_search_span(search_space_id=1, query_chars=99),
            otel.kb_persist_span(document_type="NOTE", document_id=7),
            otel.compaction_span(reason="overflow", messages_in=120),
            otel.interrupt_span(interrupt_type="permission_ask"),
            otel.permission_asked_span(permission="edit", pattern="/x/**"),
        )
        for manager in span_managers:
            with manager as sp:
                assert sp is not None
                sp.set_attribute("ok", True)
|
||||
|
||||
|
||||
class TestEnabledIntegration:
    """When OTel is wired but no SDK exporter is bound, the API still works."""

    def test_span_attaches_attributes(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # Use the API tracer (no-op-ish but real Span objects).
        monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
        monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
        assert otel.reload_for_tests() is True

        # Should not raise even when set_attributes/record_exception fall through
        # to an SDK that isn't actually installed.
        with otel.tool_call_span("scrape_webpage", input_size=10) as active_span:
            active_span.set_attribute("tool.output.size", 200)
            active_span.set_attribute("tool.truncated", False)
        with otel.model_call_span(model_id="m", provider="p") as active_span:
            active_span.set_attribute("retry.count", 3)
|
||||
56
surfsense_backend/tests/unit/services/test_revert_service.py
Normal file
56
surfsense_backend/tests/unit/services/test_revert_service.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""Unit tests for the agent revert service (Tier 5.3)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.services.revert_service import can_revert
|
||||
|
||||
|
||||
class _FakeAction:
|
||||
def __init__(self, *, user_id: Any, tool_name: str = "edit_file") -> None:
|
||||
self.user_id = user_id
|
||||
self.tool_name = tool_name
|
||||
|
||||
|
||||
class TestCanRevert:
    """Authorization matrix for reverting a recorded agent action."""

    def test_owner_can_revert_their_own_action(self) -> None:
        action = _FakeAction(user_id="user-123")
        assert can_revert(requester_user_id="user-123", action=action, is_admin=False)

    def test_other_user_cannot_revert(self) -> None:
        action = _FakeAction(user_id="user-123")
        assert not can_revert(
            requester_user_id="someone-else", action=action, is_admin=False
        )

    def test_admin_always_allowed(self) -> None:
        assert can_revert(
            requester_user_id="anybody",
            action=_FakeAction(user_id="user-123"),
            is_admin=True,
        )

    def test_admin_can_revert_anonymous_action(self) -> None:
        assert can_revert(
            requester_user_id="admin", action=_FakeAction(user_id=None), is_admin=True
        )

    def test_anonymous_action_blocks_non_admin(self) -> None:
        assert not can_revert(
            requester_user_id="user-1", action=_FakeAction(user_id=None), is_admin=False
        )

    def test_uuid_string_normalization(self) -> None:
        """``user_id`` may be a UUID object; comparison should still work."""
        import uuid

        owner = uuid.uuid4()
        action = _FakeAction(user_id=owner)
        # Same UUID, passed as string from the requesting side.
        assert can_revert(requester_user_id=str(owner), action=action, is_admin=False)
|
||||
Loading…
Add table
Add a link
Reference in a new issue