mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
feat: improved agent streaming
This commit is contained in:
parent
afb4b09cde
commit
c110f5b955
60 changed files with 8068 additions and 303 deletions
|
|
@ -15,6 +15,17 @@ from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
|
|||
from app.agents.new_chat.tools.registry import ToolDefinition
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeRuntime:
|
||||
"""Minimal stand-in for ``ToolRuntime`` used in unit tests.
|
||||
|
||||
``ActionLogMiddleware`` reads ``runtime.config['configurable']['turn_id']``
|
||||
to populate the new ``chat_turn_id`` column (see migration 135).
|
||||
"""
|
||||
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeRequest:
|
||||
"""Minimal stand-in for ToolCallRequest used in unit tests."""
|
||||
|
|
@ -120,6 +131,9 @@ class TestActionLogMiddlewarePersistence:
|
|||
"args": {"color": "red", "size": 3},
|
||||
"id": "tc-abc",
|
||||
},
|
||||
runtime=_FakeRuntime(
|
||||
config={"configurable": {"turn_id": "42:1700000000000"}}
|
||||
),
|
||||
)
|
||||
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
|
||||
handler = AsyncMock(return_value=result_msg)
|
||||
|
|
@ -142,6 +156,32 @@ class TestActionLogMiddlewarePersistence:
|
|||
assert row.error is None
|
||||
assert row.reverse_descriptor is None
|
||||
assert row.reversible is False
|
||||
# Migration 135: ``turn_id`` is the deprecated alias of ``tool_call_id``;
|
||||
# ``chat_turn_id`` comes from ``runtime.config['configurable']['turn_id']``.
|
||||
assert row.tool_call_id == "tc-abc"
|
||||
assert row.turn_id == "tc-abc"
|
||||
assert row.chat_turn_id == "42:1700000000000"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_turn_id_none_when_runtime_missing(
|
||||
self, patch_get_flags, fake_session_factory
|
||||
) -> None:
|
||||
"""``chat_turn_id`` falls back to NULL when ``runtime.config`` is absent."""
|
||||
captured, factory = fake_session_factory
|
||||
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
|
||||
request = _FakeRequest(
|
||||
tool_call={"name": "make_widget", "args": {}, "id": "tc-1"},
|
||||
runtime=None,
|
||||
)
|
||||
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc-1"))
|
||||
with (
|
||||
patch_get_flags(_enabled_flags()),
|
||||
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||
):
|
||||
await mw.awrap_tool_call(request, handler)
|
||||
row = captured["rows"][0]
|
||||
assert row.tool_call_id == "tc-1"
|
||||
assert row.chat_turn_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_writes_row_on_failure_and_reraises(
|
||||
|
|
@ -293,6 +333,76 @@ class TestReverseDescriptor:
|
|||
assert row.reversible is False
|
||||
|
||||
|
||||
class TestActionLogDispatch:
|
||||
"""Verify ``adispatch_custom_event`` fires after commit."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatches_action_log_event_on_success(
|
||||
self, patch_get_flags, fake_session_factory
|
||||
) -> None:
|
||||
_captured, factory = fake_session_factory
|
||||
mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
|
||||
request = _FakeRequest(
|
||||
tool_call={
|
||||
"name": "make_widget",
|
||||
"args": {"color": "red"},
|
||||
"id": "tc-evt",
|
||||
},
|
||||
runtime=_FakeRuntime(
|
||||
config={"configurable": {"turn_id": "42:1700000000000"}}
|
||||
),
|
||||
)
|
||||
result_msg = ToolMessage(content="ok", tool_call_id="tc-evt", id="msg-42")
|
||||
handler = AsyncMock(return_value=result_msg)
|
||||
|
||||
dispatch_mock = AsyncMock()
|
||||
with (
|
||||
patch_get_flags(_enabled_flags()),
|
||||
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||
patch(
|
||||
"app.agents.new_chat.middleware.action_log.adispatch_custom_event",
|
||||
dispatch_mock,
|
||||
),
|
||||
):
|
||||
await mw.awrap_tool_call(request, handler)
|
||||
|
||||
dispatch_mock.assert_awaited_once()
|
||||
call_args = dispatch_mock.await_args
|
||||
assert call_args is not None
|
||||
assert call_args.args[0] == "action_log"
|
||||
payload = call_args.args[1]
|
||||
assert payload["lc_tool_call_id"] == "tc-evt"
|
||||
assert payload["chat_turn_id"] == "42:1700000000000"
|
||||
assert payload["tool_name"] == "make_widget"
|
||||
assert payload["reversible"] is False
|
||||
assert payload["reverse_descriptor_present"] is False
|
||||
assert payload["error"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_dispatch_when_persistence_fails(self, patch_get_flags) -> None:
|
||||
"""If commit fails the dispatch is suppressed (no row to surface)."""
|
||||
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
|
||||
request = _FakeRequest(
|
||||
tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
|
||||
)
|
||||
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
|
||||
dispatch_mock = AsyncMock()
|
||||
|
||||
def _exploding_session():
|
||||
raise RuntimeError("DB is down")
|
||||
|
||||
with (
|
||||
patch_get_flags(_enabled_flags()),
|
||||
patch("app.db.shielded_async_session", side_effect=_exploding_session),
|
||||
patch(
|
||||
"app.agents.new_chat.middleware.action_log.adispatch_custom_event",
|
||||
dispatch_mock,
|
||||
),
|
||||
):
|
||||
await mw.awrap_tool_call(request, handler)
|
||||
dispatch_mock.assert_not_awaited()
|
||||
|
||||
|
||||
class TestArgsTruncation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_huge_args_payload_is_truncated(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,122 @@
|
|||
"""Tests for the desktop-mode safety ruleset.
|
||||
|
||||
In desktop mode the agent operates against the user's real disk with no
|
||||
revision history, so destructive filesystem operations must require
|
||||
explicit approval. These tests pin the set of tools that get the ``ask``
|
||||
gate so it cannot silently regress.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
||||
from app.agents.new_chat.permissions import (
|
||||
Rule,
|
||||
Ruleset,
|
||||
aggregate_action,
|
||||
evaluate_many,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# Mirror the ruleset built inside ``chat_deepagent._build_compiled_agent_blocking``
|
||||
# when ``filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER``. Keeping a
|
||||
# copy here means the rule contract has a focused regression test even when
|
||||
# the larger graph-build helper is hard to instantiate in unit tests.
|
||||
DESKTOP_SAFETY_RULESET = Ruleset(
|
||||
rules=[
|
||||
Rule(permission="rm", pattern="*", action="ask"),
|
||||
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||
Rule(permission="move_file", pattern="*", action="ask"),
|
||||
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||
Rule(permission="write_file", pattern="*", action="ask"),
|
||||
],
|
||||
origin="desktop_safety",
|
||||
)
|
||||
|
||||
SURFSENSE_DEFAULTS = Ruleset(
|
||||
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||
origin="surfsense_defaults",
|
||||
)
|
||||
|
||||
|
||||
def _action_for(tool_name: str, *rulesets: Ruleset) -> str:
|
||||
rules = evaluate_many(tool_name, [tool_name], *rulesets)
|
||||
return aggregate_action(rules)
|
||||
|
||||
|
||||
class TestDesktopSafetyRulesGateDestructiveOps:
|
||||
@pytest.mark.parametrize(
|
||||
"tool_name",
|
||||
["rm", "rmdir", "move_file", "edit_file", "write_file"],
|
||||
)
|
||||
def test_destructive_op_resolves_to_ask(self, tool_name: str) -> None:
|
||||
# surfsense_defaults says "allow */*"; desktop_safety must override
|
||||
# because it's layered later (last-match-wins).
|
||||
action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
|
||||
assert action == "ask", (
|
||||
f"{tool_name} must require approval in desktop mode "
|
||||
f"(no revert path on real disk); got {action!r}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool_name",
|
||||
["read_file", "ls", "list_tree", "grep", "glob", "cd", "pwd", "mkdir"],
|
||||
)
|
||||
def test_safe_ops_remain_allowed(self, tool_name: str) -> None:
|
||||
# Read-only and trivially-reversible tools must NOT get gated —
|
||||
# otherwise every navigation in desktop mode pops an interrupt.
|
||||
action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
|
||||
assert action == "allow", (
|
||||
f"{tool_name} should not be gated in desktop mode; got {action!r}"
|
||||
)
|
||||
|
||||
|
||||
class TestDesktopSafetyOverridesAllowDefault:
|
||||
def test_layer_order_last_match_wins(self) -> None:
|
||||
# If desktop_safety is layered BEFORE surfsense_defaults, the allow
|
||||
# default would win and the safety net would be inert. This test
|
||||
# protects against accidentally swapping the rulesets in
|
||||
# ``_build_compiled_agent_blocking``.
|
||||
action = _action_for("rm", DESKTOP_SAFETY_RULESET, SURFSENSE_DEFAULTS)
|
||||
# Layered "wrong way" — the broad allow now wins.
|
||||
assert action == "allow"
|
||||
|
||||
# Correct order: defaults < desktop_safety -> ask wins.
|
||||
action = _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
|
||||
assert action == "ask"
|
||||
|
||||
|
||||
class TestPermissionMiddlewareIntegration:
|
||||
def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None:
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agents.new_chat.errors import RejectedError
|
||||
|
||||
mw = PermissionMiddleware(rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET])
|
||||
# Stub the interrupt to a "reject" decision so we can assert the
|
||||
# ask path was taken without spinning up the LangGraph runtime.
|
||||
mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment]
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "rm",
|
||||
"args": {"path": "/Users/me/Documents/important.docx"},
|
||||
"id": "tc-rm",
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
class _FakeRuntime:
|
||||
config: dict = {"configurable": {"thread_id": "test"}}
|
||||
|
||||
with pytest.raises(RejectedError):
|
||||
mw.after_model(state, _FakeRuntime())
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
"""Tests for the default auto-approval list in ``hitl.request_approval``.
|
||||
|
||||
These pin the policy that low-stakes connector creation tools (drafts,
|
||||
new-file creates) skip the HITL interrupt by default. Without this set,
|
||||
every "draft my newsletter" turn used to fire ~3 interrupts before any
|
||||
useful work happened.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.tools.hitl import (
|
||||
DEFAULT_AUTO_APPROVED_TOOLS,
|
||||
HITLResult,
|
||||
request_approval,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestDefaultAutoApprovedToolsList:
|
||||
def test_set_contains_expected_creation_tools(self) -> None:
|
||||
# If anyone changes the policy list, we want a single test to
|
||||
# update so the contract is explicit. Keep this in sync with
|
||||
# ``hitl.DEFAULT_AUTO_APPROVED_TOOLS``.
|
||||
expected = {
|
||||
"create_gmail_draft",
|
||||
"update_gmail_draft",
|
||||
"create_notion_page",
|
||||
"create_confluence_page",
|
||||
"create_google_drive_file",
|
||||
"create_dropbox_file",
|
||||
"create_onedrive_file",
|
||||
}
|
||||
assert expected == DEFAULT_AUTO_APPROVED_TOOLS
|
||||
|
||||
def test_set_is_immutable(self) -> None:
|
||||
# frozenset prevents accidental at-runtime mutation that would
|
||||
# silently widen the auto-approval surface.
|
||||
assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset)
|
||||
|
||||
def test_send_tools_are_not_auto_approved(self) -> None:
|
||||
# External-broadcast tools must always prompt.
|
||||
for tool_name in (
|
||||
"send_gmail_email",
|
||||
"send_discord_message",
|
||||
"send_teams_message",
|
||||
"delete_notion_page",
|
||||
"create_calendar_event",
|
||||
"delete_calendar_event",
|
||||
):
|
||||
assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, (
|
||||
f"{tool_name} must remain HITL-gated"
|
||||
)
|
||||
|
||||
|
||||
class TestRequestApprovalAutoBypass:
|
||||
def test_auto_approved_tool_skips_interrupt(self) -> None:
|
||||
# No interrupt mock set up — if the function attempted to call
|
||||
# ``langgraph.types.interrupt`` it would raise GraphInterrupt.
|
||||
# The fact that we get a clean HITLResult proves the bypass.
|
||||
result = request_approval(
|
||||
action_type="gmail_draft_creation",
|
||||
tool_name="create_gmail_draft",
|
||||
params={"to": "alice@example.com", "subject": "hi", "body": "hey"},
|
||||
)
|
||||
assert isinstance(result, HITLResult)
|
||||
assert result.rejected is False
|
||||
assert result.decision_type == "auto_approved"
|
||||
# Original params are preserved untouched (no user edits possible).
|
||||
assert result.params == {
|
||||
"to": "alice@example.com",
|
||||
"subject": "hi",
|
||||
"body": "hey",
|
||||
}
|
||||
|
||||
def test_non_listed_tool_still_attempts_interrupt(self) -> None:
|
||||
# A tool NOT in the default list must reach ``langgraph.interrupt``.
|
||||
# Outside a runnable context that call raises a RuntimeError —
|
||||
# which is exactly the signal we want: the bypass did NOT fire.
|
||||
with pytest.raises(RuntimeError, match="runnable context"):
|
||||
request_approval(
|
||||
action_type="gmail_email_send",
|
||||
tool_name="send_gmail_email",
|
||||
params={"to": "alice@example.com", "subject": "hi", "body": "hey"},
|
||||
)
|
||||
|
||||
def test_user_trusted_tools_still_take_precedence(self) -> None:
|
||||
# ``trusted_tools`` (per-connector "always allow" from MCP/UI)
|
||||
# was checked BEFORE the default list and must keep working
|
||||
# for tools outside the default list.
|
||||
result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name="my_custom_mcp_tool",
|
||||
params={"x": 1},
|
||||
trusted_tools=["my_custom_mcp_tool"],
|
||||
)
|
||||
assert result.decision_type == "trusted"
|
||||
assert result.rejected is False
|
||||
|
||||
def test_auto_approved_overrides_no_trusted_tools(self) -> None:
|
||||
# When trusted_tools is empty and tool is in the default list,
|
||||
# we should still bypass — proves the order in request_approval.
|
||||
result = request_approval(
|
||||
action_type="notion_page_creation",
|
||||
tool_name="create_notion_page",
|
||||
params={"title": "Plan"},
|
||||
trusted_tools=[],
|
||||
)
|
||||
assert result.decision_type == "auto_approved"
|
||||
|
|
@ -0,0 +1,333 @@
|
|||
"""Cloud-mode behavior tests for the new ``rm`` and ``rmdir`` filesystem tools.
|
||||
|
||||
The tools build ``Command(update=...)`` payloads that the persistence
|
||||
middleware applies at end of turn. These tests stub out the backend and
|
||||
runtime to assert the staging payload shape:
|
||||
|
||||
* ``rm`` queues into ``pending_deletes`` and tombstones state files.
|
||||
* ``rm`` rejects directories, ``/documents``, root, and the anonymous doc.
|
||||
* ``rmdir`` queues into ``pending_dir_deletes`` and rejects non-empty dirs.
|
||||
* ``rmdir`` un-stages a same-turn ``mkdir`` rather than queuing a delete.
|
||||
* ``rmdir`` refuses to drop the cwd or any of its ancestors.
|
||||
* ``KBPostgresBackend`` view-helpers honor staged deletes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD):
|
||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||
middleware._filesystem_mode = mode
|
||||
middleware._custom_tool_descriptions = {}
|
||||
return middleware
|
||||
|
||||
|
||||
def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc"):
|
||||
state = state or {}
|
||||
state.setdefault("cwd", "/documents")
|
||||
return SimpleNamespace(state=state, tool_call_id=tool_call_id)
|
||||
|
||||
|
||||
class _KBBackendStub(KBPostgresBackend):
|
||||
"""Construct-able subclass of :class:`KBPostgresBackend` for tests.
|
||||
|
||||
We bypass the real ``__init__`` (which expects a runtime + DB session)
|
||||
and inject just the methods the rm/rmdir tools touch. The class
|
||||
inheritance keeps ``isinstance(backend, KBPostgresBackend)`` checks
|
||||
inside the tools happy, which is what gates them from the desktop
|
||||
code path.
|
||||
"""
|
||||
|
||||
def __init__(self, *, children=None, file_data=None) -> None:
|
||||
self.als_info = AsyncMock(return_value=children or [])
|
||||
self._load_file_data = AsyncMock(
|
||||
return_value=(file_data, 17) if file_data is not None else None
|
||||
)
|
||||
|
||||
|
||||
def _make_backend_stub(*, children=None, file_data=None) -> KBPostgresBackend:
|
||||
return _KBBackendStub(children=children, file_data=file_data)
|
||||
|
||||
|
||||
def _bind_backend(middleware, backend):
|
||||
"""Inject a backend resolver onto the middleware test instance."""
|
||||
middleware._get_backend = lambda runtime: backend
|
||||
return backend
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# rm
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRmStaging:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stages_delete_and_tombstones_state(self):
|
||||
m = _make_middleware()
|
||||
_bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]}))
|
||||
runtime = _runtime(
|
||||
{
|
||||
"cwd": "/documents",
|
||||
"files": {"/documents/notes.md": {"content": ["hello"]}},
|
||||
"doc_id_by_path": {"/documents/notes.md": 17},
|
||||
},
|
||||
tool_call_id="tc-1",
|
||||
)
|
||||
|
||||
tool = m._create_rm_tool()
|
||||
result = await tool.coroutine("/documents/notes.md", runtime=runtime)
|
||||
|
||||
assert hasattr(result, "update"), f"expected Command, got {result!r}"
|
||||
update = result.update
|
||||
assert update["pending_deletes"] == [
|
||||
{"path": "/documents/notes.md", "tool_call_id": "tc-1"}
|
||||
]
|
||||
assert update["files"] == {"/documents/notes.md": None}
|
||||
assert update["doc_id_by_path"] == {"/documents/notes.md": None}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_documents_root(self):
|
||||
m = _make_middleware()
|
||||
runtime = _runtime()
|
||||
tool = m._create_rm_tool()
|
||||
result = await tool.coroutine("/documents", runtime=runtime)
|
||||
assert isinstance(result, str)
|
||||
assert "refusing to rm" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_root(self):
|
||||
m = _make_middleware()
|
||||
runtime = _runtime()
|
||||
tool = m._create_rm_tool()
|
||||
result = await tool.coroutine("/", runtime=runtime)
|
||||
assert isinstance(result, str)
|
||||
assert "refusing to rm" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_directory_via_staged_dirs(self):
|
||||
m = _make_middleware()
|
||||
runtime = _runtime(
|
||||
{
|
||||
"staged_dirs": ["/documents/team-x"],
|
||||
}
|
||||
)
|
||||
tool = m._create_rm_tool()
|
||||
result = await tool.coroutine("/documents/team-x", runtime=runtime)
|
||||
assert isinstance(result, str)
|
||||
assert "directory" in result.lower()
|
||||
assert "rmdir" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_directory_via_listing(self):
|
||||
m = _make_middleware()
|
||||
_bind_backend(
|
||||
m,
|
||||
_make_backend_stub(
|
||||
children=[{"path": "/documents/foo/x.md", "is_dir": False}]
|
||||
),
|
||||
)
|
||||
runtime = _runtime()
|
||||
tool = m._create_rm_tool()
|
||||
result = await tool.coroutine("/documents/foo", runtime=runtime)
|
||||
assert isinstance(result, str)
|
||||
assert "directory" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_anonymous_doc(self):
|
||||
m = _make_middleware()
|
||||
runtime = _runtime(
|
||||
{
|
||||
"kb_anon_doc": {
|
||||
"path": "/documents/uploaded.xml",
|
||||
"title": "uploaded",
|
||||
"content": "",
|
||||
"chunks": [],
|
||||
}
|
||||
}
|
||||
)
|
||||
tool = m._create_rm_tool()
|
||||
result = await tool.coroutine("/documents/uploaded.xml", runtime=runtime)
|
||||
assert isinstance(result, str)
|
||||
assert "read-only" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drops_path_from_dirty_paths(self):
|
||||
m = _make_middleware()
|
||||
_bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]}))
|
||||
runtime = _runtime(
|
||||
{
|
||||
"files": {"/documents/notes.md": {"content": ["x"]}},
|
||||
"doc_id_by_path": {"/documents/notes.md": 17},
|
||||
"dirty_paths": ["/documents/notes.md"],
|
||||
}
|
||||
)
|
||||
tool = m._create_rm_tool()
|
||||
result = await tool.coroutine("/documents/notes.md", runtime=runtime)
|
||||
update = result.update
|
||||
# First element is _CLEAR sentinel; the rest must NOT contain the
|
||||
# rm'd path.
|
||||
dirty = update.get("dirty_paths") or []
|
||||
assert "/documents/notes.md" not in dirty[1:]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# rmdir
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRmdirStaging:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stages_dir_delete_when_empty_and_db_backed(self):
|
||||
m = _make_middleware()
|
||||
backend = _bind_backend(m, _make_backend_stub(children=[]))
|
||||
# Override _load_file_data to return None (folder, not a file) and
|
||||
# parent listing to claim the folder exists.
|
||||
backend._load_file_data = AsyncMock(return_value=None)
|
||||
backend.als_info = AsyncMock(
|
||||
side_effect=[
|
||||
[], # children of /documents/proj
|
||||
[
|
||||
{"path": "/documents/proj", "is_dir": True},
|
||||
], # parent listing
|
||||
]
|
||||
)
|
||||
runtime = _runtime(
|
||||
{
|
||||
"cwd": "/documents",
|
||||
},
|
||||
tool_call_id="tc-rd",
|
||||
)
|
||||
|
||||
tool = m._create_rmdir_tool()
|
||||
result = await tool.coroutine("/documents/proj", runtime=runtime)
|
||||
|
||||
assert hasattr(result, "update")
|
||||
update = result.update
|
||||
assert update["pending_dir_deletes"] == [
|
||||
{"path": "/documents/proj", "tool_call_id": "tc-rd"}
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_non_empty(self):
|
||||
m = _make_middleware()
|
||||
_bind_backend(
|
||||
m,
|
||||
_make_backend_stub(
|
||||
children=[{"path": "/documents/proj/x.md", "is_dir": False}]
|
||||
),
|
||||
)
|
||||
runtime = _runtime()
|
||||
tool = m._create_rmdir_tool()
|
||||
result = await tool.coroutine("/documents/proj", runtime=runtime)
|
||||
assert isinstance(result, str)
|
||||
assert "not empty" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unstages_same_turn_mkdir(self):
|
||||
m = _make_middleware()
|
||||
_bind_backend(m, _make_backend_stub(children=[]))
|
||||
runtime = _runtime(
|
||||
{
|
||||
"cwd": "/documents",
|
||||
"staged_dirs": ["/documents/scratch"],
|
||||
},
|
||||
tool_call_id="tc-rd",
|
||||
)
|
||||
tool = m._create_rmdir_tool()
|
||||
result = await tool.coroutine("/documents/scratch", runtime=runtime)
|
||||
|
||||
assert hasattr(result, "update")
|
||||
update = result.update
|
||||
assert "pending_dir_deletes" not in update
|
||||
# _CLEAR sentinel + remaining items (in this case, none).
|
||||
staged_after = update["staged_dirs"]
|
||||
assert staged_after[0] == "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00"
|
||||
assert "/documents/scratch" not in staged_after[1:]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_root(self):
|
||||
m = _make_middleware()
|
||||
runtime = _runtime()
|
||||
tool = m._create_rmdir_tool()
|
||||
for victim in ("/", "/documents"):
|
||||
result = await tool.coroutine(victim, runtime=runtime)
|
||||
assert isinstance(result, str)
|
||||
assert "refusing to rmdir" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_cwd(self):
|
||||
m = _make_middleware()
|
||||
runtime = _runtime({"cwd": "/documents/proj"})
|
||||
tool = m._create_rmdir_tool()
|
||||
result = await tool.coroutine("/documents/proj", runtime=runtime)
|
||||
assert isinstance(result, str)
|
||||
assert "cwd" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_ancestor_of_cwd(self):
|
||||
m = _make_middleware()
|
||||
runtime = _runtime({"cwd": "/documents/proj/sub"})
|
||||
tool = m._create_rmdir_tool()
|
||||
result = await tool.coroutine("/documents/proj", runtime=runtime)
|
||||
assert isinstance(result, str)
|
||||
assert "cwd" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_files(self):
|
||||
m = _make_middleware()
|
||||
_bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]}))
|
||||
runtime = _runtime()
|
||||
tool = m._create_rmdir_tool()
|
||||
result = await tool.coroutine("/documents/notes.md", runtime=runtime)
|
||||
assert isinstance(result, str)
|
||||
assert "is a file" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KBPostgresBackend view filter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKBPostgresBackendDeleteFilter:
|
||||
"""als_info / glob / grep should suppress paths queued for delete."""
|
||||
|
||||
def _make_backend(self, state: dict[str, Any]) -> KBPostgresBackend:
|
||||
runtime = SimpleNamespace(state=state)
|
||||
backend = KBPostgresBackend(search_space_id=1, runtime=runtime)
|
||||
return backend
|
||||
|
||||
def test_pending_filesystem_view_returns_deleted_paths(self):
|
||||
backend = self._make_backend(
|
||||
{
|
||||
"pending_deletes": [
|
||||
{"path": "/documents/x.md", "tool_call_id": "t1"},
|
||||
],
|
||||
"pending_dir_deletes": [
|
||||
{"path": "/documents/d1", "tool_call_id": "t2"},
|
||||
],
|
||||
}
|
||||
)
|
||||
removed, alias, deleted_dirs = backend._pending_filesystem_view({})
|
||||
assert "/documents/x.md" in removed
|
||||
assert "/documents/d1" in deleted_dirs
|
||||
assert alias == {}
|
||||
|
||||
def test_dir_suppressed_covers_descendants(self):
|
||||
backend = self._make_backend({})
|
||||
deleted_dirs = {"/documents/d"}
|
||||
assert backend._is_dir_suppressed("/documents/d", deleted_dirs)
|
||||
assert backend._is_dir_suppressed("/documents/d/x.md", deleted_dirs)
|
||||
assert backend._is_dir_suppressed("/documents/d/sub/y.md", deleted_dirs)
|
||||
assert not backend._is_dir_suppressed("/documents/other.md", deleted_dirs)
|
||||
|
|
@ -98,10 +98,54 @@ class TestInitialFilesystemState:
|
|||
state = _initial_filesystem_state()
|
||||
assert state["cwd"] == "/documents"
|
||||
assert state["staged_dirs"] == []
|
||||
assert state["staged_dir_tool_calls"] == {}
|
||||
assert state["pending_moves"] == []
|
||||
assert state["pending_deletes"] == []
|
||||
assert state["pending_dir_deletes"] == []
|
||||
assert state["doc_id_by_path"] == {}
|
||||
assert state["dirty_paths"] == []
|
||||
assert state["dirty_path_tool_calls"] == {}
|
||||
assert state["kb_priority"] == []
|
||||
assert state["kb_matched_chunk_ids"] == {}
|
||||
assert state["kb_anon_doc"] is None
|
||||
assert state["tree_version"] == 0
|
||||
|
||||
|
||||
class TestMultiEditSamePathCoalescing:
|
||||
"""Multi-edit-same-path turns must coalesce into ONE binding record.
|
||||
|
||||
The persistence body uses ``dirty_path_tool_calls[path]`` to find the
|
||||
tool_call_id that produced the current state on disk. Because
|
||||
``dirty_paths`` dedupes via :func:`_add_unique_reducer` the second
|
||||
edit doesn't append a new path entry — and because
|
||||
``_dict_merge_with_tombstones_reducer`` lets the right-hand side
|
||||
overwrite, the LATEST tool_call_id wins. That's the correct behavior
|
||||
for snapshotting: revert restores to the pre-mutation state, and
|
||||
multiple back-to-back edits in one turn coalesce into a single
|
||||
revisible op (the user sees ONE Revert button per turn-per-path,
|
||||
not N).
|
||||
"""
|
||||
|
||||
def test_dirty_paths_dedupes_repeated_writes(self):
|
||||
# ``_add_unique_reducer`` is applied to ``dirty_paths``. Two writes
|
||||
# to the same path produce one entry, not two.
|
||||
first = _add_unique_reducer([], ["/documents/a.md"])
|
||||
second = _add_unique_reducer(first, ["/documents/a.md"])
|
||||
assert second == ["/documents/a.md"]
|
||||
|
||||
def test_dirty_path_tool_calls_keeps_latest_tool_call_id(self):
|
||||
# First write tags the path with tcid-1.
|
||||
merged = _dict_merge_with_tombstones_reducer({}, {"/documents/a.md": "tcid-1"})
|
||||
# Second write to the same path tags it with tcid-2 (latest wins).
|
||||
merged = _dict_merge_with_tombstones_reducer(
|
||||
merged, {"/documents/a.md": "tcid-2"}
|
||||
)
|
||||
assert merged == {"/documents/a.md": "tcid-2"}
|
||||
|
||||
def test_rm_tombstones_dirty_path_tool_call(self):
|
||||
# ``rm`` writes ``{path: None}`` into dirty_path_tool_calls to
|
||||
# prevent a stale binding from leaking past the delete.
|
||||
merged = _dict_merge_with_tombstones_reducer(
|
||||
{"/documents/a.md": "tcid-1"}, {"/documents/a.md": None}
|
||||
)
|
||||
assert merged == {}
|
||||
|
|
|
|||
0
surfsense_backend/tests/unit/db/__init__.py
Normal file
0
surfsense_backend/tests/unit/db/__init__.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
"""Smoke test for the ``134_relax_revision_fks`` Alembic migration.
|
||||
|
||||
A full apply/rollback test would require a live Postgres; here we verify
|
||||
the migration module's static contract:
|
||||
|
||||
* The chain wires it as a successor of ``133_drop_documents_content_hash_unique``.
|
||||
* ``upgrade()`` declares two FK creations with ``ondelete='SET NULL'``
|
||||
(one for ``document_revisions.document_id``, one for
|
||||
``folder_revisions.folder_id``).
|
||||
* ``downgrade()`` re-establishes ``ondelete='CASCADE'`` after draining
|
||||
orphaned revisions.
|
||||
|
||||
If any of these invariants regress the snapshot/revert pipeline silently
|
||||
loses the ability to undo ``rm`` / ``rmdir`` on environments that ran the
|
||||
migration "down" or never ran it at all.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_MIGRATION_PATH = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "alembic"
|
||||
/ "versions"
|
||||
/ "134_relax_revision_fks.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_migration():
|
||||
"""Load the migration module by file path (no package import needed)."""
|
||||
spec = importlib.util.spec_from_file_location("_migration_134", _MIGRATION_PATH)
|
||||
assert spec and spec.loader, "could not load migration spec"
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def test_migration_chain_revision_ids() -> None:
    """The chain must wire 134 as the direct successor of 133."""
    # The migration file uses short numeric revision IDs to match the
    # in-tree convention (cf. ``133`` -> ``134``); the ``134_<slug>.py``
    # filename is documentation, not the canonical revision string.
    mod = _load_migration()
    assert getattr(mod, "revision", None) == "134"
    assert getattr(mod, "down_revision", None) == "133"
|
||||
|
||||
|
||||
def test_migration_exposes_upgrade_and_downgrade() -> None:
    """Both Alembic entry points must exist and be callable."""
    mod = _load_migration()
    for entry_point in ("upgrade", "downgrade"):
        fn = getattr(mod, entry_point, None)
        assert callable(fn), f"{entry_point}() is required"
|
||||
|
||||
|
||||
def test_upgrade_creates_set_null_fks_for_both_revision_tables() -> None:
    """``upgrade()`` must relax both revision FKs to ON DELETE SET NULL."""
    body = inspect.getsource(_load_migration().upgrade)
    for table in ("document_revisions", "folder_revisions"):
        assert table in body
    # Both new FKs MUST be ON DELETE SET NULL — that's the entire point
    # of the migration: snapshots must outlive their parent row.
    assert body.count('ondelete="SET NULL"') >= 2
    # And the ``document_id`` / ``folder_id`` columns become nullable.
    assert "nullable=True" in body
|
||||
|
||||
|
||||
def test_downgrade_drains_orphans_then_restores_cascade() -> None:
    """``downgrade()`` must delete orphans, then restore CASCADE/NOT NULL."""
    body = inspect.getsource(_load_migration().downgrade)
    # Drain orphaned rows BEFORE we can re-impose NOT NULL.
    for drain_stmt in (
        "DELETE FROM document_revisions WHERE document_id IS NULL",
        "DELETE FROM folder_revisions WHERE folder_id IS NULL",
    ):
        assert drain_stmt in body
    # Then restore the original CASCADE/NOT NULL contract.
    assert body.count('ondelete="CASCADE"') >= 2
    assert "nullable=False" in body
|
||||
|
|
@ -168,6 +168,8 @@ class TestModeSpecificPrompts:
|
|||
"edit_file",
|
||||
"move_file",
|
||||
"mkdir",
|
||||
"rm",
|
||||
"rmdir",
|
||||
"list_tree",
|
||||
"grep",
|
||||
):
|
||||
|
|
@ -182,6 +184,8 @@ class TestModeSpecificPrompts:
|
|||
"edit_file",
|
||||
"move_file",
|
||||
"mkdir",
|
||||
"rm",
|
||||
"rmdir",
|
||||
"list_tree",
|
||||
"grep",
|
||||
):
|
||||
|
|
@ -190,6 +194,18 @@ class TestModeSpecificPrompts:
|
|||
assert "/documents/" not in text, f"{name} mentions cloud namespace"
|
||||
assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"
|
||||
|
||||
def test_cloud_descs_include_rm_and_rmdir(self):
|
||||
descs = _build_tool_descriptions(FilesystemMode.CLOUD)
|
||||
assert "rm" in descs and "rmdir" in descs
|
||||
assert "Deletes a single file" in descs["rm"]
|
||||
assert "Deletes an empty directory" in descs["rmdir"]
|
||||
assert "rmdir" in descs["rmdir"] and "POSIX" in descs["rmdir"]
|
||||
|
||||
def test_desktop_descs_warn_about_irreversibility(self):
|
||||
descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER)
|
||||
assert "NOT reversible" in descs["rm"]
|
||||
assert "NOT reversible" in descs["rmdir"]
|
||||
|
||||
def test_sandbox_addendum_appended_when_available(self):
|
||||
prompt = _build_filesystem_system_prompt(
|
||||
FilesystemMode.CLOUD, sandbox_available=True
|
||||
|
|
|
|||
|
|
@ -0,0 +1,309 @@
|
|||
"""Unit tests for the kb_persistence snapshot helpers.
|
||||
|
||||
The full ``commit_staged_filesystem_state`` body exercises a real session
|
||||
in integration tests; here we verify the building blocks used by the
|
||||
snapshot/revert pipeline:
|
||||
|
||||
* ``_find_action_ids_batch`` issues a SINGLE query for N tool_call_ids
|
||||
(regression guard against the N+1 lookup pattern).
|
||||
* ``_mark_action_reversible`` is a no-op when ``action_id`` is ``None``.
|
||||
* ``_doc_revision_payload`` and ``_load_chunks_for_snapshot`` produce the
|
||||
shape the snapshot helpers consume.
|
||||
|
||||
These tests use ``MagicMock`` / ``AsyncMock`` against a fake session so
|
||||
the assertions run in milliseconds and don't require Postgres.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.middleware import kb_persistence
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeResult:
|
||||
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
|
||||
self._rows = rows or []
|
||||
self._scalar = scalar
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
return list(self._rows)
|
||||
|
||||
def scalar_one_or_none(self) -> Any:
|
||||
return self._scalar
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.execute = AsyncMock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_action_ids_batch_issues_single_query() -> None:
    """The lookup MUST be a single ``IN (...)`` SELECT, not N selects."""
    session = _FakeSession()
    fake_rows = [
        MagicMock(id=11, tool_call_id="tc-a"),
        MagicMock(id=22, tool_call_id="tc-b"),
        MagicMock(id=33, tool_call_id="tc-c"),
    ]
    session.execute.return_value = _FakeResult(rows=fake_rows)

    resolved = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=1,
        tool_call_ids={"tc-a", "tc-b", "tc-c"},
    )

    assert resolved == {"tc-a": 11, "tc-b": 22, "tc-c": 33}
    assert session.execute.await_count == 1, (
        "Snapshot binding must batch into ONE query; got "
        f"{session.execute.await_count} (regression: N+1 lookup pattern)."
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_thread_id_missing() -> None:
    """No thread id means nothing to resolve — the DB must not be touched."""
    session = _FakeSession()
    resolved = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=None,
        tool_call_ids={"tc-a"},
    )
    assert resolved == {}
    assert session.execute.await_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_no_calls() -> None:
    """An empty ``tool_call_ids`` set must not issue any query."""
    session = _FakeSession()
    resolved = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=42,
        tool_call_ids=set(),
    )
    assert resolved == {}
    assert session.execute.await_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mark_action_reversible_is_noop_for_null_id() -> None:
    """A ``None`` action_id is silently ignored — no UPDATE is issued."""
    session = _FakeSession()
    await kb_persistence._mark_action_reversible(session, action_id=None)  # type: ignore[arg-type]
    assert session.execute.await_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mark_action_reversible_runs_update_for_real_id() -> None:
    """A concrete action_id triggers exactly one UPDATE statement."""
    session = _FakeSession()
    await kb_persistence._mark_action_reversible(session, action_id=99)  # type: ignore[arg-type]
    assert session.execute.await_count == 1
|
||||
|
||||
|
||||
def test_doc_revision_payload_captures_metadata_virtual_path() -> None:
    """Snapshot helpers must capture ``metadata_before`` for revert reuse."""
    fake_doc = MagicMock()
    fake_doc.content = "body"
    fake_doc.title = "notes.md"
    fake_doc.folder_id = 7
    fake_doc.document_metadata = {"virtual_path": "/documents/team/notes.md"}

    payload = kb_persistence._doc_revision_payload(
        fake_doc, chunks_before=[{"content": "x"}]
    )

    expected = {
        "title_before": "notes.md",
        "folder_id_before": 7,
        "content_before": "body",
        "chunks_before": [{"content": "x"}],
        "metadata_before": {"virtual_path": "/documents/team/notes.md"},
    }
    for key, value in expected.items():
        assert payload[key] == value
|
||||
|
||||
|
||||
def test_doc_revision_payload_handles_missing_metadata() -> None:
    """Docs without ``document_metadata`` snapshot ``metadata_before=None``."""
    bare_doc = MagicMock()
    bare_doc.content = ""
    bare_doc.title = ""
    bare_doc.folder_id = None
    bare_doc.document_metadata = None

    payload = kb_persistence._doc_revision_payload(bare_doc)
    assert payload["metadata_before"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_chunks_for_snapshot_returns_content_only() -> None:
    """Snapshot chunks intentionally omit embeddings (regenerated on revert)."""
    session = _FakeSession()
    session.execute.return_value = _FakeResult(
        rows=[MagicMock(content="alpha"), MagicMock(content="beta")]
    )

    snapshot = await kb_persistence._load_chunks_for_snapshot(
        session,
        doc_id=42,  # type: ignore[arg-type]
    )
    assert snapshot == [{"content": "alpha"}, {"content": "beta"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deferred reversibility-flip dispatches.
|
||||
#
|
||||
# The snapshot helpers used to dispatch ``action_log_updated`` directly
|
||||
# from inside the SAVEPOINT block. That meant the SSE side-channel
|
||||
# could tell the UI a row was reversible while the OUTER transaction
|
||||
# was still pending — and if the outer commit failed, every SAVEPOINT
|
||||
# rolled back too, leaving the UI in a state inconsistent with
|
||||
# durable storage. The deferred-dispatch contract fixes that:
|
||||
#
|
||||
# • when a ``deferred_dispatches`` list is provided, the helper
|
||||
# APPENDS the action_id and does NOT dispatch;
|
||||
# • the caller (``commit_staged_filesystem_state``) flushes the list
|
||||
# only AFTER ``await session.commit()`` succeeds; on rollback it
|
||||
# clears the list so nothing is emitted.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _NestedCtx:
|
||||
"""Async context manager mimicking ``session.begin_nested()``."""
|
||||
|
||||
async def __aenter__(self) -> _NestedCtx:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pre_write_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Helpers MUST queue dispatches when ``deferred_dispatches`` is set."""
    # Hand-built session double: begin_nested() yields a no-op SAVEPOINT
    # context, execute() returns no prior rows, flush() is inert.
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
    session.flush = AsyncMock()

    def _add(rev: Any) -> None:
        # Simulate the PK assignment a real flush would perform.
        rev.id = 17

    session.add = MagicMock(side_effect=_add)

    dispatched: list[int] = []

    async def _fake_dispatch(action_id: int | None) -> None:
        if action_id is not None:
            dispatched.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
    )

    deferred: list[int] = []
    doc = MagicMock(id=99, document_metadata={"virtual_path": "/documents/x.md"})
    doc.title = "x.md"
    doc.folder_id = None
    doc.content = "body"

    rev_id = await kb_persistence._snapshot_document_pre_write(
        session,  # type: ignore[arg-type]
        doc=doc,
        action_id=42,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=deferred,
    )

    # The returned revision id carries the faked PK from ``_add``.
    assert rev_id == 17
    # Inline dispatch must NOT have fired; the action_id is queued.
    assert dispatched == []
    assert deferred == [42]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pre_write_snapshot_dispatches_inline_when_list_omitted(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Direct callers (no outer transaction) keep the legacy inline dispatch."""
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
    session.flush = AsyncMock()

    def _add(rev: Any) -> None:
        # Simulate the PK assignment a real flush would perform.
        rev.id = 7

    session.add = MagicMock(side_effect=_add)

    dispatched: list[int] = []

    async def _fake_dispatch(action_id: int | None) -> None:
        if action_id is not None:
            dispatched.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
    )

    doc = MagicMock(id=11, document_metadata={"virtual_path": "/documents/y.md"})
    doc.title = "y.md"
    doc.folder_id = None
    doc.content = "body"

    await kb_persistence._snapshot_document_pre_write(
        session,  # type: ignore[arg-type]
        doc=doc,
        action_id=88,
        search_space_id=1,
        turn_id="t-1",
        # No deferred_dispatches arg — fall back to inline dispatch.
    )

    # The legacy path fires the SSE-side dispatch immediately.
    assert dispatched == [88]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pre_mkdir_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Folder mkdir snapshots honour the same deferred-dispatch contract."""
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock()  # _mark_action_reversible calls execute
    session.flush = AsyncMock()

    def _add(rev: Any) -> None:
        # Simulate the PK assignment a real flush would perform.
        rev.id = 3

    session.add = MagicMock(side_effect=_add)

    dispatched: list[int] = []

    async def _fake_dispatch(action_id: int | None) -> None:
        if action_id is not None:
            dispatched.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
    )

    deferred: list[int] = []
    # BUG FIX: ``name`` is a reserved Mock constructor argument —
    # ``MagicMock(name="f")`` names the mock itself, leaving
    # ``folder.name`` as a child MagicMock rather than the string "f".
    # Assign the attribute after construction so the snapshot helper
    # sees a real folder name.
    folder = MagicMock(id=2, parent_id=None, position="a0")
    folder.name = "f"

    await kb_persistence._snapshot_folder_pre_mkdir(
        session,  # type: ignore[arg-type]
        folder=folder,
        action_id=55,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=deferred,
    )

    # Inline dispatch must NOT have fired; the action_id is queued.
    assert dispatched == []
    assert deferred == [55]
|
||||
139
surfsense_backend/tests/unit/middleware/test_knowledge_tree.py
Normal file
139
surfsense_backend/tests/unit/middleware/test_knowledge_tree.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
"""Unit tests for ``KnowledgeTreeMiddleware`` rendering.
|
||||
|
||||
The empty-folder marker is critical UX: without it, the LLM cannot
|
||||
distinguish a leaf folder containing one document from a leaf folder
|
||||
that has no descendants at all, and ends up firing ``rmdir`` on
|
||||
non-empty folders. These tests pin the rendering contract so that
|
||||
contract cannot silently regress.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.middleware.knowledge_tree import KnowledgeTreeMiddleware
|
||||
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT
|
||||
|
||||
|
||||
def _compute(folder_paths: list[str], doc_paths: list[str]) -> set[str]:
    """Thin wrapper over the middleware's static non-empty-folder helper."""
    return KnowledgeTreeMiddleware._compute_non_empty_folders(
        folder_paths, doc_paths
    )
|
||||
|
||||
|
||||
class TestComputeNonEmptyFolders:
    """Pin the non-empty-folder computation used by the tree renderer."""

    def test_folder_with_direct_document_is_non_empty(self):
        leaf = f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"
        marked = _compute(
            [leaf],
            [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass/southwest.pdf.xml"],
        )
        assert leaf in marked

    def test_truly_empty_leaf_folder_is_not_non_empty(self):
        marked = _compute([f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"], [])
        assert marked == set()

    def test_documents_propagate_up_to_all_ancestors(self):
        chain = [
            f"{DOCUMENTS_ROOT}/A",
            f"{DOCUMENTS_ROOT}/A/B",
            f"{DOCUMENTS_ROOT}/A/B/C",
        ]
        marked = _compute(chain, [f"{DOCUMENTS_ROOT}/A/B/C/file.xml"])
        assert marked == set(chain)

    def test_chain_with_subfolders_marks_only_leaf_empty(self):
        # POSIX-like semantic: a folder is "empty" only if it has no
        # immediate children (docs OR sub-folders). The model needs this
        # because parallel ``rmdir`` calls all see the same starting state,
        # so trying to rmdir a parent before its children is never safe.
        chain = [
            f"{DOCUMENTS_ROOT}/X",
            f"{DOCUMENTS_ROOT}/X/Y",
            f"{DOCUMENTS_ROOT}/X/Y/Z",
        ]
        marked = _compute(chain, [])
        # Only ``X/Y/Z`` (the leaf) is empty; each ancestor has a
        # sub-folder child and must NOT carry the ``(empty)`` marker.
        assert marked == {f"{DOCUMENTS_ROOT}/X", f"{DOCUMENTS_ROOT}/X/Y"}

    def test_sibling_with_doc_does_not_mark_other_sibling_non_empty(self):
        # Mirrors a real DB layout where every intermediate folder is
        # materialized in the ``folders`` table.
        marked = _compute(
            [
                f"{DOCUMENTS_ROOT}/Travel",
                f"{DOCUMENTS_ROOT}/Travel/Boarding Pass",
                f"{DOCUMENTS_ROOT}/Travel/Notes",
            ],
            [f"{DOCUMENTS_ROOT}/Travel/Notes/itinerary.xml"],
        )
        # ``Travel`` has children and ``Notes`` holds the doc, while the
        # sibling leaf ``Boarding Pass`` stays empty.
        assert f"{DOCUMENTS_ROOT}/Travel" in marked
        assert f"{DOCUMENTS_ROOT}/Travel/Notes" in marked
        assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" not in marked
|
||||
|
||||
|
||||
class TestFormatTreeRendering:
    """Integration check: empty leaf gets ``(empty)`` marker; non-empty doesn't."""

    def _render(
        self,
        folder_paths: list[str],
        doc_specs: list[dict],
    ) -> str:
        """Render a tree for the given folder paths and doc attribute dicts."""
        from app.agents.new_chat.path_resolver import PathIndex

        # Folder ids are 1-based positions in ``folder_paths``; tests use
        # ``folder_paths.index(...) + 1`` to reference them.
        index = PathIndex(
            folder_paths={i + 1: p for i, p in enumerate(folder_paths)},
        )

        class _Row:
            # Attribute-bag stand-in for a document ORM row.
            def __init__(self, **kw):
                self.__dict__.update(kw)

        docs = [_Row(**spec) for spec in doc_specs]

        mw = KnowledgeTreeMiddleware(
            search_space_id=1,
            filesystem_mode=None,  # type: ignore[arg-type]
        )
        return mw._format_tree(index, docs)

    def test_renders_empty_marker_only_for_truly_empty_folders(self):
        # Reproduces the failure scenario from the bug report:
        # ``Boarding Pass`` is empty (its only doc was just deleted), while
        # ``Tax Returns`` still has ``federal.pdf``. All intermediate
        # folders are present in the index, mirroring the real DB layout.
        folder_paths = [
            "/documents/File Upload",
            "/documents/File Upload/2026-04-08",
            "/documents/File Upload/2026-04-08/Travel",
            "/documents/File Upload/2026-04-08/Travel/Boarding Pass",
            "/documents/File Upload/2026-04-15",
            "/documents/File Upload/2026-04-15/Finance",
            "/documents/File Upload/2026-04-15/Finance/Tax Returns",
        ]
        # +1 converts the 0-based list index into the 1-based folder id
        # assigned by ``_render``.
        tax_returns_folder_id = (
            folder_paths.index("/documents/File Upload/2026-04-15/Finance/Tax Returns")
            + 1
        )
        rendered = self._render(
            folder_paths=folder_paths,
            doc_specs=[
                {
                    "id": 100,
                    "title": "federal.pdf",
                    "folder_id": tax_returns_folder_id,
                },
            ],
        )
        assert "Boarding Pass/ (empty)" in rendered
        assert "Tax Returns/ (empty)" not in rendered
        # Intermediate ancestors of the doc must NOT be marked empty.
        assert "Finance/ (empty)" not in rendered
        assert "2026-04-15/ (empty)" not in rendered
|
||||
|
|
@ -69,3 +69,74 @@ def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path):
|
|||
assert write.error is not None
|
||||
assert "parent directory" in write.error
|
||||
assert not (tmp_path / "tempoo").exists()
|
||||
|
||||
|
||||
def test_local_backend_delete_file_success(tmp_path: Path):
    """Deleting an existing file succeeds and removes it from disk."""
    backend = LocalFolderBackend(str(tmp_path))
    victim = tmp_path / "delete-me.md"
    victim.write_text("bye")

    outcome = backend.delete_file("/delete-me.md")
    assert outcome.error is None
    assert outcome.path == "/delete-me.md"
    assert not victim.exists()
|
||||
|
||||
|
||||
def test_local_backend_delete_file_rejects_directory(tmp_path: Path):
    """``delete_file`` must refuse directories and leave them intact."""
    backend = LocalFolderBackend(str(tmp_path))
    subdir = tmp_path / "subdir"
    subdir.mkdir()

    outcome = backend.delete_file("/subdir")
    assert outcome.error is not None
    assert "directory" in outcome.error
    assert subdir.exists()
|
||||
|
||||
|
||||
def test_local_backend_delete_file_missing_returns_error(tmp_path: Path):
    """Deleting a non-existent path reports an error instead of raising."""
    backend = LocalFolderBackend(str(tmp_path))

    outcome = backend.delete_file("/nope.md")
    assert outcome.error is not None
    assert "not found" in outcome.error
|
||||
|
||||
|
||||
def test_local_backend_rmdir_success(tmp_path: Path):
    """Removing an empty directory succeeds and deletes it on disk."""
    backend = LocalFolderBackend(str(tmp_path))
    empty_dir = tmp_path / "empty"
    empty_dir.mkdir()

    outcome = backend.rmdir("/empty")
    assert outcome.error is None
    assert outcome.path == "/empty"
    assert not empty_dir.exists()
|
||||
|
||||
|
||||
def test_local_backend_rmdir_rejects_non_empty(tmp_path: Path):
    """``rmdir`` must refuse a directory that still has children."""
    backend = LocalFolderBackend(str(tmp_path))
    parent = tmp_path / "withkid"
    parent.mkdir()
    child = parent / "child.md"
    child.write_text("x")

    outcome = backend.rmdir("/withkid")
    assert outcome.error is not None
    assert "not empty" in outcome.error
    assert child.exists()
|
||||
|
||||
|
||||
def test_local_backend_rmdir_rejects_file(tmp_path: Path):
    """``rmdir`` on a regular file must fail with a clear error."""
    backend = LocalFolderBackend(str(tmp_path))
    (tmp_path / "f.md").write_text("x")

    outcome = backend.rmdir("/f.md")
    assert outcome.error is not None
    assert "not a directory" in outcome.error
|
||||
|
||||
|
||||
def test_local_backend_rmdir_rejects_root(tmp_path: Path):
    """``rmdir /`` MUST fail. The exact error wording comes from
    ``_resolve_virtual`` (root resolves to outside the sandbox); what
    matters is that the call returns an error and does NOT delete the
    sandbox root on disk."""
    backend = LocalFolderBackend(str(tmp_path))

    outcome = backend.rmdir("/")
    assert outcome.error is not None
    assert "Invalid path" in outcome.error or "root" in outcome.error
    assert tmp_path.exists()
|
||||
|
|
|
|||
0
surfsense_backend/tests/unit/routes/__init__.py
Normal file
0
surfsense_backend/tests/unit/routes/__init__.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
"""Unit tests for the edit-from-arbitrary-position helpers inside ``new_chat_routes``.
|
||||
|
||||
The regenerate route's edit-from-position path introduces:
|
||||
* ``_find_pre_turn_checkpoint_id`` — walks LangGraph checkpoint tuples
|
||||
newest-first and picks the first one whose ``metadata["turn_id"]``
|
||||
differs from the edited turn. That checkpoint is the rewind target
|
||||
(state immediately before the edited turn started).
|
||||
* ``RegenerateRequest`` accepts ``from_message_id`` + ``revert_actions``
|
||||
with a validator that prevents callers from requesting a revert pass
|
||||
without specifying which turn to roll back.
|
||||
|
||||
These are pure-Python helpers that don't need a live DB, so we exercise
|
||||
them with a small ``CheckpointTuple``-shaped namespace and direct
|
||||
schema instantiation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.routes.new_chat_routes import _find_pre_turn_checkpoint_id
|
||||
from app.schemas.new_chat import RegenerateRequest
|
||||
|
||||
|
||||
def _cp(checkpoint_id: str, turn_id: str | None) -> SimpleNamespace:
|
||||
"""Build a fake ``CheckpointTuple`` with the metadata shape we read."""
|
||||
return SimpleNamespace(
|
||||
config={"configurable": {"checkpoint_id": checkpoint_id}},
|
||||
metadata={"turn_id": turn_id} if turn_id is not None else {},
|
||||
)
|
||||
|
||||
|
||||
class TestFindPreTurnCheckpointId:
    """Pin rewind-target selection for the edit-from-position flow."""

    def test_returns_last_pre_turn_checkpoint_when_editing_latest_turn(self) -> None:
        # Newest-first: T2 is the most-recent turn. The latest non-T2
        # checkpoint (cp2) is the rewind target — state immediately
        # before T2 began.
        history = [
            _cp("cp4", "T2"),
            _cp("cp3", "T2"),
            _cp("cp2", "T1"),
            _cp("cp1", "T1"),
        ]
        assert _find_pre_turn_checkpoint_id(history, turn_id="T2") == "cp2"

    def test_returns_pre_turn_checkpoint_when_later_turns_exist(self) -> None:
        # Regression guard: walking newest-first and returning the FIRST
        # cp with ``turn_id != target`` would pick a later-turn
        # checkpoint, NOT the pre-turn boundary. Editing T2 must rewind
        # to the latest T1 checkpoint (cp2), never to the latest T3
        # checkpoint (cp6).
        history = [
            _cp("cp6", "T3"),
            _cp("cp5", "T3"),
            _cp("cp4", "T2"),
            _cp("cp3", "T2"),
            _cp("cp2", "T1"),
            _cp("cp1", "T1"),
        ]
        assert _find_pre_turn_checkpoint_id(history, turn_id="T2") == "cp2"

    def test_returns_none_when_editing_first_turn(self) -> None:
        # No pre-turn boundary exists; the caller is expected to fall
        # back to the oldest checkpoint or special-case "first turn of
        # the thread".
        history = [
            _cp("cp4", "T2"),
            _cp("cp3", "T2"),
            _cp("cp2", "T1"),
            _cp("cp1", "T1"),
        ]
        assert _find_pre_turn_checkpoint_id(history, turn_id="T1") is None

    def test_returns_none_when_only_edited_turn_present(self) -> None:
        history = [_cp("cp2", "T2"), _cp("cp1", "T2")]
        assert _find_pre_turn_checkpoint_id(history, turn_id="T2") is None

    def test_returns_none_for_empty_history(self) -> None:
        assert _find_pre_turn_checkpoint_id([], turn_id="T1") is None

    def test_legacy_checkpoints_without_turn_id_count_as_pre_turn(self) -> None:
        # Checkpoints written before migration 136 carry no
        # ``metadata.turn_id``; they predate the edited turn and are
        # therefore eligible rewind targets.
        legacy = SimpleNamespace(
            config={"configurable": {"checkpoint_id": "cp2"}},
            metadata=None,
        )
        history = [_cp("cp3", "T2"), legacy, _cp("cp1", "T1")]
        # Oldest-first walk: cp1(T1) tracked, cp2(legacy/None) tracked,
        # then cp3(T2) crosses the boundary -> return cp2.
        assert _find_pre_turn_checkpoint_id(history, turn_id="T2") == "cp2"

    def test_skips_checkpoint_missing_checkpoint_id_in_config(self) -> None:
        # A checkpoint whose ``config["configurable"]`` lacks the
        # ``checkpoint_id`` key (corrupt / partial) is skipped; the last
        # known good target wins instead of crashing.
        corrupt = SimpleNamespace(
            config={"configurable": {}}, metadata={"turn_id": "T1"}
        )
        history = [_cp("cp3", "T2"), corrupt, _cp("cp1", "T1")]
        # cp1(T1) tracked, corrupt skipped, cp3(T2) -> return cp1.
        assert _find_pre_turn_checkpoint_id(history, turn_id="T2") == "cp1"
|
||||
|
||||
|
||||
class TestRegenerateRequestValidation:
    """Pin the schema coupling between ``revert_actions`` and ``from_message_id``."""

    def test_revert_actions_requires_from_message_id(self) -> None:
        with pytest.raises(Exception) as exc:
            RegenerateRequest(
                search_space_id=1,
                user_query="hi",
                revert_actions=True,
            )
        assert "from_message_id" in str(exc.value).lower()

    def test_from_message_id_without_revert_is_allowed(self) -> None:
        request = RegenerateRequest(
            search_space_id=1,
            user_query="hi",
            from_message_id=42,
        )
        assert request.from_message_id == 42
        assert request.revert_actions is False

    def test_revert_actions_with_from_message_id_passes(self) -> None:
        request = RegenerateRequest(
            search_space_id=1,
            user_query="hi",
            from_message_id=42,
            revert_actions=True,
        )
        assert request.revert_actions is True
|
||||
530
surfsense_backend/tests/unit/routes/test_revert_turn_route.py
Normal file
530
surfsense_backend/tests/unit/routes/test_revert_turn_route.py
Normal file
|
|
@ -0,0 +1,530 @@
|
|||
"""Unit tests for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
|
||||
|
||||
The per-turn batch revert route walks rows in reverse ``created_at``
|
||||
order, reverts each independently, and returns a per-action result
|
||||
list. Partial success is normal — the response status
|
||||
is ``"partial"`` whenever any row could not be reverted, but we never
|
||||
collapse the whole batch into a 4xx.
|
||||
|
||||
These tests stub ``load_thread`` / ``revert_action`` and feed a fake
|
||||
session, so they exercise the route's dispatch logic without a real DB.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.routes import agent_revert_route
|
||||
from app.services.revert_service import RevertOutcome
|
||||
|
||||
|
||||
@dataclass
class _FakeAction:
    """Row-shaped stand-in for a persisted action-log record.

    Only the columns this route's tests read are modelled.
    """

    id: int
    tool_name: str
    # Owner of the action; presumably compared against the requesting
    # user by the route — confirm against the route's ownership check.
    user_id: str | None = "u1"
    # Non-None presumably marks this row as itself being the revert of
    # another action — confirm in revert_service.
    reverse_of: int | None = None
    # Structured error payload recorded for failed actions, if any.
    error: dict | None = None
|
||||
|
||||
|
||||
@dataclass
class _FakeUser:
    """Minimal authenticated-user stand-in passed to the route under test."""

    id: str = "u1"
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ScalarResult:
|
||||
rows: list[Any]
|
||||
|
||||
def first(self) -> Any:
|
||||
return self.rows[0] if self.rows else None
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
return list(self.rows)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Result:
|
||||
rows: list[Any] = field(default_factory=list)
|
||||
|
||||
def scalars(self) -> _ScalarResult:
|
||||
return _ScalarResult(self.rows)
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
# ``_was_already_reverted_batch`` calls ``.all()`` directly on
|
||||
# the row-tuple result (no ``.scalars()`` indirection). The
|
||||
# rows queued for that helper are list[(revert_id, original_id)].
|
||||
return list(self.rows)
|
||||
|
||||
|
||||
class _FakeNestedCtx:
|
||||
"""Async context manager that mimics ``session.begin_nested()``.
|
||||
|
||||
The route raises a sentinel exception inside this block to roll back
|
||||
bad rows. We just pass the exception through.
|
||||
"""
|
||||
|
||||
async def __aenter__(self) -> _FakeNestedCtx:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
# Returning False (or None) propagates the exception; the route
|
||||
# catches its own sentinel above this layer.
|
||||
return False
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Minimal AsyncSession stand-in for the revert-turn route.
|
||||
|
||||
Holds a queue of result objects; each ``execute(...)`` pops the next
|
||||
one. The route calls ``execute`` exactly once per query so this maps
|
||||
cleanly onto the assertion order of the test.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._results: list[_Result] = []
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
# Count execute() calls to assert "no N+1 reverts".
|
||||
self.execute_call_count = 0
|
||||
|
||||
def queue(self, *results: _Result) -> None:
|
||||
self._results.extend(results)
|
||||
|
||||
async def execute(self, _stmt: Any) -> _Result:
|
||||
self.execute_call_count += 1
|
||||
if not self._results:
|
||||
return _Result(rows=[])
|
||||
return self._results.pop(0)
|
||||
|
||||
def begin_nested(self) -> _FakeNestedCtx:
|
||||
return _FakeNestedCtx()
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rolled_back = True
|
||||
|
||||
|
||||
def _enabled_flags() -> AgentFeatureFlags:
    """Return flags with the whole revert pipeline switched on."""
    return AgentFeatureFlags(
        disable_new_agent_stack=False,
        enable_action_log=True,
        enable_revert_route=True,
    )
||||
@pytest.fixture
def patch_get_flags():
    """Factory fixture: patch the route's ``get_flags`` to a fixed value.

    Returns a callable so each test can choose its own flag combination
    and use the returned patcher as a context manager.
    """

    def _patch(flags: AgentFeatureFlags):
        return patch(
            "app.routes.agent_revert_route.get_flags",
            return_value=flags,
        )

    return _patch
||||
class TestFlagGuard:
    """The route must refuse to run while the revert feature flag is off."""

    @pytest.mark.asyncio
    async def test_returns_503_when_revert_route_disabled(
        self, patch_get_flags
    ) -> None:
        """A disabled ``enable_revert_route`` surfaces as an HTTP 503."""
        disabled = AgentFeatureFlags(
            disable_new_agent_stack=False,
            enable_action_log=True,
            enable_revert_route=False,
        )
        fake_session = _FakeSession()
        with patch_get_flags(disabled), pytest.raises(Exception) as excinfo:
            await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="42:1700000000000",
                session=fake_session,
                user=_FakeUser(),
            )
        # HTTPException carries ``status_code``; getattr keeps the assert
        # informative if a different exception type escapes instead.
        assert getattr(excinfo.value, "status_code", None) == 503
||||
class TestRevertTurnDispatch:
    """Behavior of ``revert_agent_turn`` across row mixes and race cases.

    Each test queues canned ``_Result`` objects on ``_FakeSession`` in the
    exact order the route issues queries: first the turn's action rows,
    then the single batched idempotency probe.
    """

    @pytest.mark.asyncio
    async def test_empty_turn_returns_ok_with_no_rows(self, patch_get_flags) -> None:
        """A turn with no logged actions commits and reports zero rows."""
        session = _FakeSession()
        session.queue(_Result(rows=[]))  # rows query returns nothing
        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-empty",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.total == 0
        assert response.results == []
        assert session.committed is True

    @pytest.mark.asyncio
    async def test_walks_rows_in_reverse_and_reverts_each(
        self, patch_get_flags
    ) -> None:
        """Rows are processed newest-first and each one is reverted."""
        rows = [
            _FakeAction(id=10, tool_name="rm"),
            _FakeAction(id=9, tool_name="write_file"),
            _FakeAction(id=8, tool_name="mkdir"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched ``_was_already_reverted_batch`` probe replaces
        # the previous N per-row SELECTs.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            return RevertOutcome(
                status="ok",
                message=f"reverted-{action.id}",
                new_action_id=100 + action.id,
            )

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-3",
                session=session,
                user=_FakeUser(),
            )

        assert response.status == "ok"
        assert response.total == 3
        assert response.reverted == 3
        assert [r.action_id for r in response.results] == [10, 9, 8]
        assert all(r.status == "reverted" for r in response.results)
        assert response.results[0].new_action_id == 110
        # Only TWO ``execute`` calls regardless of the row count: one
        # for the rows query, one for the batched
        # ``_was_already_reverted_batch`` probe. Regression guard
        # against re-introducing the per-row N+1 lookup.
        assert session.execute_call_count == 2, (
            "revert-turn loop must batch idempotency probes; got "
            f"{session.execute_call_count} execute() calls (expected 2)."
        )

    @pytest.mark.asyncio
    async def test_already_reverted_rows_are_marked_idempotent(
        self, patch_get_flags
    ) -> None:
        """Rows with an existing revert are reported without re-reverting."""
        rows = [_FakeAction(id=5, tool_name="edit_file")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch probe returns ``[(revert_id, original_id)]``.
        session.queue(_Result(rows=[(42, 5)]))

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-i",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.already_reverted == 1
        assert response.results[0].status == "already_reverted"
        assert response.results[0].new_action_id == 42
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_revert_action_skips_existing_revert_rows(
        self, patch_get_flags
    ) -> None:
        """Rows that are themselves reverts (``_revert:*``) are skipped."""
        rows = [_FakeAction(id=99, tool_name="_revert:edit_file", reverse_of=42)]
        session = _FakeSession()
        session.queue(_Result(rows=rows))

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-rev",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.results[0].status == "skipped"
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_partial_success_when_some_rows_not_reversible(
        self, patch_get_flags
    ) -> None:
        """A mix of reverted and not-reversible rows yields ``partial``."""
        rows = [
            _FakeAction(id=2, tool_name="send_email"),
            _FakeAction(id=1, tool_name="edit_file"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched idempotency probe.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.tool_name == "send_email":
                return RevertOutcome(
                    status="not_reversible",
                    message="connector revert not yet implemented",
                )
            return RevertOutcome(status="ok", message="ok", new_action_id=500)

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-mix",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "partial"
        assert response.reverted == 1
        assert response.not_reversible == 1
        statuses = sorted(r.status for r in response.results)
        assert statuses == ["not_reversible", "reverted"]

    @pytest.mark.asyncio
    async def test_unexpected_exception_marks_row_failed_not_batch(
        self, patch_get_flags
    ) -> None:
        """One row blowing up fails only that row, not the whole batch."""
        rows = [
            _FakeAction(id=20, tool_name="edit_file"),
            _FakeAction(id=21, tool_name="edit_file"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched idempotency probe.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.id == 20:
                raise RuntimeError("disk on fire")
            return RevertOutcome(status="ok", message="ok", new_action_id=999)

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-fail",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "partial"
        assert response.failed == 1
        assert response.reverted == 1
        bad = next(r for r in response.results if r.action_id == 20)
        assert bad.status == "failed"
        assert "disk on fire" in (bad.error or "")
        good = next(r for r in response.results if r.action_id == 21)
        assert good.status == "reverted"

    @pytest.mark.asyncio
    async def test_permission_denied_when_other_user_owns_action(
        self, patch_get_flags
    ) -> None:
        """A row owned by another user is denied without calling revert."""
        rows = [_FakeAction(id=7, tool_name="edit_file", user_id="someone-else")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch idempotency probe (no prior reverts).
        session.queue(_Result(rows=[]))

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-perm",
                session=session,
                user=_FakeUser(id="not-owner"),
            )
        assert response.status == "partial"
        assert response.results[0].status == "permission_denied"
        # ``permission_denied`` has its own dedicated counter so the
        # response invariant ``total == sum(counters)`` always holds
        # without overloading ``not_reversible`` (which historically
        # absorbed this case and confused frontend toasts).
        assert response.permission_denied == 1
        assert response.not_reversible == 0
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_counter_invariant_holds_across_mixed_outcomes(
        self, patch_get_flags
    ) -> None:
        """Every row is accounted for in EXACTLY ONE counter.

        Mixes one of every supported outcome (reverted, already_reverted,
        not_reversible, permission_denied, failed, skipped) and asserts
        that the sum of counters equals ``response.total``.
        """
        rows = [
            _FakeAction(id=10, tool_name="edit_file"),  # ok
            _FakeAction(id=9, tool_name="edit_file"),  # already_reverted
            _FakeAction(id=8, tool_name="send_email"),  # not_reversible
            _FakeAction(id=7, tool_name="rm", user_id="other"),  # permission_denied
            _FakeAction(id=6, tool_name="edit_file"),  # failed
            _FakeAction(id=5, tool_name="_revert:edit_file", reverse_of=99),  # skipped
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched probe; only id=9 has a prior revert.
        # Schema: list[(revert_id, original_id)].
        session.queue(_Result(rows=[(42, 9)]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.id == 10:
                return RevertOutcome(status="ok", message="ok", new_action_id=500)
            if action.id == 8:
                return RevertOutcome(
                    status="not_reversible",
                    message="connector revert not yet implemented",
                )
            if action.id == 6:
                raise RuntimeError("boom")
            raise AssertionError(f"unexpected revert call for {action.id}")

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route,
                "revert_action",
                AsyncMock(side_effect=_fake_revert),
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-mixed-all",
                session=session,
                user=_FakeUser(),  # only id=7 has a different user_id
            )

        assert response.total == len(rows) == 6
        bucket_sum = (
            response.reverted
            + response.already_reverted
            + response.not_reversible
            + response.permission_denied
            + response.failed
            + response.skipped
        )
        assert bucket_sum == response.total, (
            "Counter invariant broken: total "
            f"({response.total}) != sum of counters ({bucket_sum}). "
            f"Counters: reverted={response.reverted}, "
            f"already_reverted={response.already_reverted}, "
            f"not_reversible={response.not_reversible}, "
            f"permission_denied={response.permission_denied}, "
            f"failed={response.failed}, skipped={response.skipped}"
        )
        assert response.reverted == 1
        assert response.already_reverted == 1
        assert response.not_reversible == 1
        assert response.permission_denied == 1
        assert response.failed == 1
        assert response.skipped == 1

    @pytest.mark.asyncio
    async def test_integrity_error_translates_to_already_reverted(
        self, patch_get_flags
    ) -> None:
        """The partial unique index on ``reverse_of`` raises
        ``IntegrityError`` when a concurrent revert wins the race against
        the pre-flight ``_was_already_reverted`` SELECT. The route MUST
        recover by re-querying for the winning revert id and returning
        ``status="already_reverted"`` (not ``"failed"``) so racing
        clients see consistent idempotent semantics.
        """
        from sqlalchemy.exc import IntegrityError

        rows = [_FakeAction(id=33, tool_name="edit_file")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch pre-flight probe: nothing yet (we'll race).
        session.queue(_Result(rows=[]))
        # Post-IntegrityError fallback uses the SCALAR
        # ``_was_already_reverted`` (single-id lookup) so it pulls
        # ``[777]`` via ``.scalars().first()``.
        session.queue(_Result(rows=[777]))

        async def _racing_revert(_session, *, action, requester_user_id):
            raise IntegrityError("INSERT", {}, Exception("dup reverse_of"))

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route,
                "revert_action",
                AsyncMock(side_effect=_racing_revert),
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-race",
                session=session,
                user=_FakeUser(),
            )

        assert response.failed == 0, (
            "IntegrityError must NOT surface as a failed row; the unique "
            "index is the durable expression of idempotency."
        )
        assert response.already_reverted == 1
        assert response.results[0].status == "already_reverted"
        assert response.results[0].new_action_id == 777
|
@ -0,0 +1,370 @@
|
|||
"""Unit tests for the filesystem-tool branches of ``revert_service``.
|
||||
|
||||
Covers:
|
||||
|
||||
* Exact-name dispatch — ``rmdir`` does NOT mis-route to the document
|
||||
branch (``"rmdir".startswith("rm")`` would mis-route under the legacy
|
||||
prefix-based dispatch).
|
||||
* ``rm`` revert re-INSERTs a fresh document from the snapshot, including
|
||||
re-creating chunks. Falls back to ``(folder_id_before, title_before)``
|
||||
when ``metadata_before["virtual_path"]`` is missing.
|
||||
* ``write_file`` create-revert (``content_before IS NULL``) DELETEs the
|
||||
document.
|
||||
* ``rmdir`` revert re-INSERTs a fresh folder from the snapshot.
|
||||
* ``mkdir`` revert DELETEs the empty folder; reports ``tool_unavailable``
|
||||
when the folder gained children.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from app.services import revert_service
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _stub_embeddings(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace the real embedding call with cheap zero vectors.

    Autouse so every test in this module avoids loading a model.
    """

    def _fake_embed(texts):
        return [np.zeros(8, dtype=np.float32) for _ in texts]

    monkeypatch.setattr(revert_service, "embed_texts", _fake_embed)
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeResult:
|
||||
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
|
||||
self._rows = rows or []
|
||||
self._scalar = scalar
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
return list(self._rows)
|
||||
|
||||
def scalar_one_or_none(self) -> Any:
|
||||
return self._scalar
|
||||
|
||||
def scalars(self) -> Any:
|
||||
return _FakeScalarsProxy(self._rows)
|
||||
|
||||
|
||||
class _FakeScalarsProxy:
|
||||
def __init__(self, rows: list[Any]) -> None:
|
||||
self._rows = rows
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._rows[0] if self._rows else None
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.execute = AsyncMock()
|
||||
self.added: list[Any] = []
|
||||
self.deleted: list[Any] = []
|
||||
self.flush = AsyncMock()
|
||||
# session.get(Model, pk) lookup
|
||||
self.get = AsyncMock(return_value=None)
|
||||
|
||||
async def _flush_assigning_ids() -> None:
|
||||
for obj in self.added:
|
||||
if getattr(obj, "id", None) is None:
|
||||
obj.id = 999
|
||||
|
||||
self.flush.side_effect = _flush_assigning_ids
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def add_all(self, objs: list[Any]) -> None:
|
||||
self.added.extend(objs)
|
||||
|
||||
|
||||
def _action(*, tool_name: str, action_id: int = 7):
|
||||
return MagicMock(
|
||||
id=action_id,
|
||||
tool_name=tool_name,
|
||||
thread_id=1,
|
||||
search_space_id=2,
|
||||
user_id="user-1",
|
||||
reverse_descriptor=None,
|
||||
)
|
||||
|
||||
|
||||
def _doc_revision(
|
||||
*,
|
||||
document_id: int | None = None,
|
||||
content_before: str | None = "old content",
|
||||
title_before: str | None = "notes.md",
|
||||
folder_id_before: int | None = 5,
|
||||
chunks_before: list[dict[str, str]] | None = None,
|
||||
metadata_before: dict[str, str] | None = None,
|
||||
):
|
||||
revision = MagicMock()
|
||||
revision.id = 100
|
||||
revision.document_id = document_id
|
||||
revision.search_space_id = 2
|
||||
revision.content_before = content_before
|
||||
revision.title_before = title_before
|
||||
revision.folder_id_before = folder_id_before
|
||||
revision.chunks_before = chunks_before or []
|
||||
revision.metadata_before = metadata_before
|
||||
return revision
|
||||
|
||||
|
||||
def _folder_revision(
|
||||
*,
|
||||
folder_id: int | None = None,
|
||||
name_before: str | None = "team",
|
||||
parent_id_before: int | None = None,
|
||||
position_before: str | None = "a0",
|
||||
):
|
||||
revision = MagicMock()
|
||||
revision.id = 200
|
||||
revision.folder_id = folder_id
|
||||
revision.search_space_id = 2
|
||||
revision.name_before = name_before
|
||||
revision.parent_id_before = parent_id_before
|
||||
revision.position_before = position_before
|
||||
return revision
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exact-name dispatch regression guards
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExactDispatch:
    """Regression: ``rmdir`` MUST NOT route to the document branch."""

    @pytest.mark.asyncio
    async def test_rmdir_does_not_misroute_to_document(self) -> None:
        """Exact-name lookup sends ``rmdir`` to the folder branch."""
        # Under the legacy `startswith("rm")` dispatch this would land in
        # the document branch; exact-name lookup hits `_FOLDER_TOOLS`.
        db = _FakeSession()
        rmdir_action = _action(tool_name="rmdir")
        # No folder revisions exist for this action.
        db.execute.return_value = _FakeResult(rows=[])
        outcome = await revert_service.revert_action(
            db,  # type: ignore[arg-type]
            action=rmdir_action,
            requester_user_id="user-1",
        )
        assert outcome.status == "not_reversible"
        assert "folder_revisions" in outcome.message

    def test_dispatch_sets_split_doc_and_folder(self) -> None:
        """Static guards on the dispatch tables themselves.

        Pins the table membership so a future refactor cannot quietly
        reintroduce the prefix-matching bug.
        """
        assert "rm" in revert_service._DOC_TOOLS
        assert "rmdir" in revert_service._FOLDER_TOOLS
        assert "rmdir" not in revert_service._DOC_TOOLS
        assert "rm" not in revert_service._FOLDER_TOOLS
        # ``move_file`` lives only in document tools (it's a doc rename).
        assert "move_file" in revert_service._DOC_TOOLS
        assert "move_file" not in revert_service._FOLDER_TOOLS
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# rm revert (re-INSERT)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRmRevert:
    """``rm`` revert re-INSERTs the document (and chunks) from its snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_document_with_chunks(self) -> None:
        """A hard-deleted doc comes back with its chunks re-created."""
        session = _FakeSession()
        revision = _doc_revision(
            document_id=None,  # row was hard-deleted
            content_before="hello world",
            title_before="x.md",
            folder_id_before=None,
            chunks_before=[{"content": "alpha"}, {"content": "beta"}],
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        # No collision check hit and the resulting query returns nothing.
        session.execute.return_value = _FakeResult(scalar=None)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )

        assert outcome.status == "ok"
        # New Document + 2 chunks must have been added.
        from app.db import Chunk, Document

        added_docs = [obj for obj in session.added if isinstance(obj, Document)]
        added_chunks = [obj for obj in session.added if isinstance(obj, Chunk)]
        assert len(added_docs) == 1
        assert added_docs[0].title == "x.md"
        assert len(added_chunks) == 2
        # Snapshot was repointed at the new doc id so a follow-up revert works.
        assert revision.document_id == added_docs[0].id

    @pytest.mark.asyncio
    async def test_falls_back_to_folder_id_and_title_for_virtual_path(
        self,
    ) -> None:
        """Missing ``metadata_before`` derives the path from folder + title."""
        session = _FakeSession()
        # Snapshot with NO metadata_before — the fallback path must kick in.
        revision = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="cap.md",
            folder_id_before=42,
            chunks_before=[],
            metadata_before=None,
        )
        # session.get(Folder, 42) returns a folder with a name.
        folder = MagicMock()
        folder.name = "team"
        folder.parent_id = None
        # First .get is for the folder lookup in the path-derivation.
        session.get = AsyncMock(return_value=folder)
        session.execute.return_value = _FakeResult(scalar=None)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )
        assert outcome.status == "ok"

    @pytest.mark.asyncio
    async def test_falls_back_to_root_path_when_no_folder(
        self,
    ) -> None:
        """metadata_before is None and folder_id_before is None still
        resolves: title fallback yields ``/documents/<title>`` so revert
        proceeds at the root of the documents tree."""
        session = _FakeSession()
        revision = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="x.md",
            folder_id_before=None,
            metadata_before=None,
        )
        # No collision in the documents tree at /documents/x.md.
        session.execute.return_value = _FakeResult(scalar=None)
        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )
        assert outcome.status == "ok"

    @pytest.mark.asyncio
    async def test_collision_with_live_doc_returns_tool_unavailable(self) -> None:
        """A live document already at the target path blocks the revert."""
        session = _FakeSession()
        revision = _doc_revision(
            document_id=None,
            content_before="hi",
            title_before="x.md",
            folder_id_before=None,
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        # SELECT for unique_identifier_hash collision hits an existing row.
        session.execute.return_value = _FakeResult(scalar=42)
        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )
        assert outcome.status == "tool_unavailable"
        assert "collide" in outcome.message
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# write_file create revert (DELETE)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWriteFileCreateRevert:
    """Reverting a ``write_file`` that CREATED a doc deletes that doc."""

    @pytest.mark.asyncio
    async def test_deletes_created_doc(self) -> None:
        """``content_before is None`` marks a creation; revert is a DELETE."""
        session = _FakeSession()
        revision = _doc_revision(
            document_id=99,
            content_before=None,  # marker for "created in this action"
            title_before=None,
        )
        outcome = await revert_service._delete_created_document(
            session,  # type: ignore[arg-type]
            revision=revision,
        )
        assert outcome.status == "ok"
        # Exactly one DELETE was issued.
        assert session.execute.await_count == 1
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# rmdir revert (re-INSERT folder)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRmdirRevert:
    """``rmdir`` revert re-INSERTs a fresh folder from its snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_folder_from_snapshot(self) -> None:
        """The snapshot's name is restored and repointed at the new id."""
        session = _FakeSession()
        revision = _folder_revision(
            folder_id=None,
            name_before="team",
            parent_id_before=None,
            position_before="a0",
        )
        outcome = await revert_service._reinsert_folder_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )
        from app.db import Folder

        assert outcome.status == "ok"
        added_folders = [obj for obj in session.added if isinstance(obj, Folder)]
        assert len(added_folders) == 1
        assert added_folders[0].name == "team"
        # Snapshot repointed at the new folder id so a follow-up revert works.
        assert revision.folder_id == added_folders[0].id
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# mkdir revert (DELETE folder)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMkdirRevert:
    """``mkdir`` revert deletes the folder, but only while it is empty."""

    @pytest.mark.asyncio
    async def test_deletes_empty_folder(self) -> None:
        """An empty folder is deleted after both emptiness checks pass."""
        session = _FakeSession()
        revision = _folder_revision(folder_id=42)
        # Both the doc-existence check and the child-folder check return None.
        session.execute.side_effect = [
            _FakeResult(scalar=None),  # docs
            _FakeResult(scalar=None),  # children
            _FakeResult(scalar=None),  # delete (no return value)
        ]
        outcome = await revert_service._delete_created_folder(
            session,  # type: ignore[arg-type]
            revision=revision,
        )
        assert outcome.status == "ok"
        # 3 executes: docs check, children check, delete.
        assert session.execute.await_count == 3

    @pytest.mark.asyncio
    async def test_reports_tool_unavailable_when_folder_has_children(self) -> None:
        """A folder that gained contents since mkdir must NOT be deleted."""
        session = _FakeSession()
        revision = _folder_revision(folder_id=42)
        # First check (docs) returns "row found".
        session.execute.return_value = _FakeResult(scalar=1)
        outcome = await revert_service._delete_created_folder(
            session,  # type: ignore[arg-type]
            revision=revision,
        )
        assert outcome.status == "tool_unavailable"
        assert "no longer empty" in outcome.message
||||
0
surfsense_backend/tests/unit/tasks/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/chat/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/chat/__init__.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
"""Unit tests for ``stream_new_chat._extract_chunk_parts``.
|
||||
|
||||
Earlier versions only handled ``isinstance(chunk.content, str)`` and
|
||||
silently dropped every other shape (Anthropic typed-block lists,
|
||||
Bedrock reasoning blocks, ``additional_kwargs.reasoning_content`` from
|
||||
a few providers). These regression tests pin those four shapes plus the
|
||||
defensive cases (``None`` chunk, mixed types, missing fields).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.chat.stream_new_chat import _extract_chunk_parts
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeChunk:
|
||||
"""Minimal stand-in for ``AIMessageChunk`` used in unit tests."""
|
||||
|
||||
content: Any = ""
|
||||
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class TestStringContent:
    """Plain-string ``chunk.content`` — the original, simplest shape."""

    def test_plain_string_content_extracts_as_text(self) -> None:
        """A bare string maps straight to ``text`` with nothing else."""
        chunk = _FakeChunk(content="hello world")
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "hello world"
        assert out["reasoning"] == ""
        assert out["tool_call_chunks"] == []

    def test_empty_string_content_yields_empty_text(self) -> None:
        """An empty string stays empty rather than becoming None."""
        chunk = _FakeChunk(content="")
        out = _extract_chunk_parts(chunk)
        assert out["text"] == ""
        assert out["reasoning"] == ""
        assert out["tool_call_chunks"] == []
||||
|
||||
class TestListContent:
    """Typed-block lists (Anthropic/Bedrock style) in ``chunk.content``."""

    def test_list_of_text_blocks_concatenates(self) -> None:
        """Consecutive text blocks concatenate in order."""
        chunk = _FakeChunk(
            content=[
                {"type": "text", "text": "Hello "},
                {"type": "text", "text": "world"},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "Hello world"
        assert out["reasoning"] == ""

    def test_mixed_text_and_reasoning_blocks(self) -> None:
        """Reasoning blocks (either payload key) and text stay separated."""
        chunk = _FakeChunk(
            content=[
                {"type": "reasoning", "reasoning": "Let me think... "},
                {"type": "reasoning", "text": "still thinking."},
                {"type": "text", "text": "The answer is 42."},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "The answer is 42."
        assert out["reasoning"] == "Let me think... still thinking."

    def test_tool_call_chunks_in_content_list_extracted(self) -> None:
        """``tool_call_chunk`` blocks surface in ``tool_call_chunks``."""
        chunk = _FakeChunk(
            content=[
                {"type": "text", "text": "Calling tool..."},
                {
                    "type": "tool_call_chunk",
                    "id": "call_123",
                    "name": "make_widget",
                    "args": '{"color":"red"}',
                },
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "Calling tool..."
        assert out["reasoning"] == ""
        assert len(out["tool_call_chunks"]) == 1
        assert out["tool_call_chunks"][0]["id"] == "call_123"
        assert out["tool_call_chunks"][0]["name"] == "make_widget"

    def test_tool_use_blocks_also_extracted(self) -> None:
        """Some providers (Anthropic) emit ``type='tool_use'`` instead."""
        chunk = _FakeChunk(
            content=[
                {
                    "type": "tool_use",
                    "id": "call_xyz",
                    "name": "search",
                },
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["tool_call_chunks"] == [
            {"type": "tool_use", "id": "call_xyz", "name": "search"}
        ]

    def test_unknown_block_types_are_ignored(self) -> None:
        """Unrecognized block types are dropped, not raised on."""
        chunk = _FakeChunk(
            content=[
                {"type": "image_url", "url": "https://example.com/x.png"},
                {"type": "text", "text": "ok"},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "ok"

    def test_blocks_without_text_field_are_ignored(self) -> None:
        """A text block missing its payload key is skipped defensively."""
        chunk = _FakeChunk(
            content=[
                {"type": "text"},  # no text/content key
                {"type": "text", "text": "kept"},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "kept"
||||
|
||||
class TestAdditionalKwargsReasoning:
    """Reasoning delivered via ``additional_kwargs`` rather than content blocks."""

    def test_reasoning_content_in_additional_kwargs(self) -> None:
        """Some providers stash reasoning in ``additional_kwargs.reasoning_content``."""
        parts = _extract_chunk_parts(
            _FakeChunk(
                content="visible answer",
                additional_kwargs={"reasoning_content": "internal monologue"},
            )
        )
        assert parts["text"] == "visible answer"
        assert parts["reasoning"] == "internal monologue"

    def test_reasoning_appended_to_typed_block_reasoning(self) -> None:
        # Block-sourced reasoning comes first, kwargs-sourced reasoning after.
        parts = _extract_chunk_parts(
            _FakeChunk(
                content=[{"type": "reasoning", "text": "from blocks. "}],
                additional_kwargs={"reasoning_content": "from kwargs."},
            )
        )
        assert parts["reasoning"] == "from blocks. from kwargs."
|
||||
|
||||
|
||||
class TestToolCallChunksAttribute:
    """Tool-call deltas exposed through the ``tool_call_chunks`` attribute."""

    def test_tool_call_chunks_attribute_extracted_alongside_string_content(
        self,
    ) -> None:
        attr_chunks = [
            {"name": "save_document", "args": '{"title":"x"}', "id": "tc-9"}
        ]
        parts = _extract_chunk_parts(
            _FakeChunk(content="streaming text", tool_call_chunks=attr_chunks)
        )
        assert parts["text"] == "streaming text"
        assert len(parts["tool_call_chunks"]) == 1
        assert parts["tool_call_chunks"][0]["id"] == "tc-9"

    def test_attribute_and_typed_block_chunks_both_collected(self) -> None:
        # Typed content blocks are collected before the attribute-level deltas.
        fake = _FakeChunk(
            content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}],
            tool_call_chunks=[{"id": "from-attr", "name": "y"}],
        )
        collected = _extract_chunk_parts(fake)["tool_call_chunks"]
        assert [tcc.get("id") for tcc in collected] == ["from-block", "from-attr"]
|
||||
|
||||
|
||||
class TestDefensive:
    """Malformed chunks must degrade to empty parts instead of raising."""

    @pytest.mark.parametrize(
        "bad_chunk",
        [None, _FakeChunk(content=None), _FakeChunk(content=42)],
    )
    def test_invalid_chunk_returns_empty_parts(self, bad_chunk: Any) -> None:
        parts = _extract_chunk_parts(bad_chunk)
        assert parts["text"] == ""
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []
|
||||
Loading…
Add table
Add a link
Reference in a new issue