feat: improved agent streaming

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-29 07:20:31 -07:00
parent afb4b09cde
commit c110f5b955
60 changed files with 8068 additions and 303 deletions

View file

@ -15,6 +15,17 @@ from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
from app.agents.new_chat.tools.registry import ToolDefinition
@dataclass
class _FakeRuntime:
"""Minimal stand-in for ``ToolRuntime`` used in unit tests.
``ActionLogMiddleware`` reads ``runtime.config['configurable']['turn_id']``
to populate the new ``chat_turn_id`` column (see migration 135).
"""
config: dict[str, Any] | None = None
@dataclass
class _FakeRequest:
"""Minimal stand-in for ToolCallRequest used in unit tests."""
@ -120,6 +131,9 @@ class TestActionLogMiddlewarePersistence:
"args": {"color": "red", "size": 3},
"id": "tc-abc",
},
runtime=_FakeRuntime(
config={"configurable": {"turn_id": "42:1700000000000"}}
),
)
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
handler = AsyncMock(return_value=result_msg)
@ -142,6 +156,32 @@ class TestActionLogMiddlewarePersistence:
assert row.error is None
assert row.reverse_descriptor is None
assert row.reversible is False
# Migration 135: ``turn_id`` is the deprecated alias of ``tool_call_id``;
# ``chat_turn_id`` comes from ``runtime.config['configurable']['turn_id']``.
assert row.tool_call_id == "tc-abc"
assert row.turn_id == "tc-abc"
assert row.chat_turn_id == "42:1700000000000"
@pytest.mark.asyncio
async def test_chat_turn_id_none_when_runtime_missing(
    self, patch_get_flags, fake_session_factory
) -> None:
    """``chat_turn_id`` falls back to NULL when ``runtime.config`` is absent."""
    rows_seen, session_factory = fake_session_factory
    middleware = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
    call = _FakeRequest(
        tool_call={"name": "make_widget", "args": {}, "id": "tc-1"},
        runtime=None,
    )
    tool_handler = AsyncMock(
        return_value=ToolMessage(content="ok", tool_call_id="tc-1")
    )

    with (
        patch_get_flags(_enabled_flags()),
        patch("app.db.shielded_async_session", side_effect=lambda: session_factory()),
    ):
        await middleware.awrap_tool_call(call, tool_handler)

    # The persisted row keeps its tool_call_id, but chat_turn_id is NULL
    # because no runtime config was available to read it from.
    persisted = rows_seen["rows"][0]
    assert persisted.tool_call_id == "tc-1"
    assert persisted.chat_turn_id is None
@pytest.mark.asyncio
async def test_writes_row_on_failure_and_reraises(
@ -293,6 +333,76 @@ class TestReverseDescriptor:
assert row.reversible is False
class TestActionLogDispatch:
    """Verify ``adispatch_custom_event`` fires after commit."""

    @pytest.mark.asyncio
    async def test_dispatches_action_log_event_on_success(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        """A successful tool call emits exactly one ``action_log`` event."""
        _rows, session_factory = fake_session_factory
        middleware = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
        call = _FakeRequest(
            tool_call={
                "name": "make_widget",
                "args": {"color": "red"},
                "id": "tc-evt",
            },
            runtime=_FakeRuntime(
                config={"configurable": {"turn_id": "42:1700000000000"}}
            ),
        )
        tool_handler = AsyncMock(
            return_value=ToolMessage(content="ok", tool_call_id="tc-evt", id="msg-42")
        )
        event_spy = AsyncMock()

        with (
            patch_get_flags(_enabled_flags()),
            patch(
                "app.db.shielded_async_session",
                side_effect=lambda: session_factory(),
            ),
            patch(
                "app.agents.new_chat.middleware.action_log.adispatch_custom_event",
                event_spy,
            ),
        ):
            await middleware.awrap_tool_call(call, tool_handler)

        event_spy.assert_awaited_once()
        awaited = event_spy.await_args
        assert awaited is not None
        # Positional args: (event_name, payload).
        assert awaited.args[0] == "action_log"
        body = awaited.args[1]
        assert body["lc_tool_call_id"] == "tc-evt"
        assert body["chat_turn_id"] == "42:1700000000000"
        assert body["tool_name"] == "make_widget"
        assert body["reversible"] is False
        assert body["reverse_descriptor_present"] is False
        assert body["error"] is False

    @pytest.mark.asyncio
    async def test_no_dispatch_when_persistence_fails(self, patch_get_flags) -> None:
        """If commit fails the dispatch is suppressed (no row to surface)."""
        middleware = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        call = _FakeRequest(
            tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
        )
        tool_handler = AsyncMock(
            return_value=ToolMessage(content="ok", tool_call_id="tc1")
        )
        event_spy = AsyncMock()

        def _exploding_session():
            raise RuntimeError("DB is down")

        with (
            patch_get_flags(_enabled_flags()),
            patch("app.db.shielded_async_session", side_effect=_exploding_session),
            patch(
                "app.agents.new_chat.middleware.action_log.adispatch_custom_event",
                event_spy,
            ),
        ):
            await middleware.awrap_tool_call(call, tool_handler)

        event_spy.assert_not_awaited()
class TestArgsTruncation:
@pytest.mark.asyncio
async def test_huge_args_payload_is_truncated(

View file

@ -0,0 +1,122 @@
"""Tests for the desktop-mode safety ruleset.
In desktop mode the agent operates against the user's real disk with no
revision history, so destructive filesystem operations must require
explicit approval. These tests pin the set of tools that get the ``ask``
gate so it cannot silently regress.
"""
from __future__ import annotations
import pytest
from app.agents.new_chat.middleware.permission import PermissionMiddleware
from app.agents.new_chat.permissions import (
Rule,
Ruleset,
aggregate_action,
evaluate_many,
)
pytestmark = pytest.mark.unit
# Mirror the ruleset built inside ``chat_deepagent._build_compiled_agent_blocking``
# when ``filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER``. Keeping a
# copy here means the rule contract has a focused regression test even when
# the larger graph-build helper is hard to instantiate in unit tests.
# Mirror the ruleset built inside ``chat_deepagent._build_compiled_agent_blocking``
# when ``filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER``. Keeping a
# copy here means the rule contract has a focused regression test even when
# the larger graph-build helper is hard to instantiate in unit tests.
DESKTOP_SAFETY_RULESET = Ruleset(
    rules=[
        # Each destructive filesystem op requires explicit user approval.
        Rule(permission="rm", pattern="*", action="ask"),
        Rule(permission="rmdir", pattern="*", action="ask"),
        Rule(permission="move_file", pattern="*", action="ask"),
        Rule(permission="edit_file", pattern="*", action="ask"),
        Rule(permission="write_file", pattern="*", action="ask"),
    ],
    origin="desktop_safety",
)

# Baseline allow-everything ruleset. Desktop safety must be layered AFTER
# this one for its "ask" rules to win (last-match-wins semantics).
SURFSENSE_DEFAULTS = Ruleset(
    rules=[Rule(permission="*", pattern="*", action="allow")],
    origin="surfsense_defaults",
)
def _action_for(tool_name: str, *rulesets: Ruleset) -> str:
    """Resolve the effective action for *tool_name* across layered *rulesets*."""
    matched = evaluate_many(tool_name, [tool_name], *rulesets)
    return aggregate_action(matched)
class TestDesktopSafetyRulesGateDestructiveOps:
    """Destructive ops resolve to ``ask``; read-only ops stay ``allow``."""

    @pytest.mark.parametrize(
        "tool_name",
        ["rm", "rmdir", "move_file", "edit_file", "write_file"],
    )
    def test_destructive_op_resolves_to_ask(self, tool_name: str) -> None:
        # surfsense_defaults says "allow */*"; desktop_safety is layered
        # later, so last-match-wins makes its "ask" rule override.
        action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        assert action == "ask", (
            f"{tool_name} must require approval in desktop mode "
            f"(no revert path on real disk); got {action!r}"
        )

    @pytest.mark.parametrize(
        "tool_name",
        ["read_file", "ls", "list_tree", "grep", "glob", "cd", "pwd", "mkdir"],
    )
    def test_safe_ops_remain_allowed(self, tool_name: str) -> None:
        # Gating read-only/trivially-reversible tools would make every
        # navigation in desktop mode pop an interrupt — keep them allowed.
        action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        assert action == "allow", (
            f"{tool_name} should not be gated in desktop mode; got {action!r}"
        )
class TestDesktopSafetyOverridesAllowDefault:
    def test_layer_order_last_match_wins(self) -> None:
        """Pin both layer orders so an accidental swap in
        ``_build_compiled_agent_blocking`` is caught immediately."""
        # Wrong order: the broad allow is layered last and wins, making
        # the safety net inert.
        assert _action_for("rm", DESKTOP_SAFETY_RULESET, SURFSENSE_DEFAULTS) == "allow"
        # Correct order: defaults < desktop_safety, so "ask" wins.
        assert _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) == "ask"
class TestPermissionMiddlewareIntegration:
    def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None:
        """A gated ``rm`` tool call must surface as ``RejectedError``."""
        from langchain_core.messages import AIMessage

        from app.agents.new_chat.errors import RejectedError

        middleware = PermissionMiddleware(
            rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET]
        )
        # Stub the interrupt with a canned "reject" decision so the ask
        # path is observable without a LangGraph runtime.
        middleware._raise_interrupt = lambda **kw: {"decision_type": "reject"}  # type: ignore[assignment]

        rm_call = {
            "name": "rm",
            "args": {"path": "/Users/me/Documents/important.docx"},
            "id": "tc-rm",
        }
        state = {"messages": [AIMessage(content="", tool_calls=[rm_call])]}

        class _StubRuntime:
            config: dict = {"configurable": {"thread_id": "test"}}

        with pytest.raises(RejectedError):
            middleware.after_model(state, _StubRuntime())

View file

@ -0,0 +1,111 @@
"""Tests for the default auto-approval list in ``hitl.request_approval``.
These pin the policy that low-stakes connector creation tools (drafts,
new-file creates) skip the HITL interrupt by default. Without this set,
every "draft my newsletter" turn used to fire ~3 interrupts before any
useful work happened.
"""
from __future__ import annotations
import pytest
from app.agents.new_chat.tools.hitl import (
DEFAULT_AUTO_APPROVED_TOOLS,
HITLResult,
request_approval,
)
pytestmark = pytest.mark.unit
class TestDefaultAutoApprovedToolsList:
    def test_set_contains_expected_creation_tools(self) -> None:
        # Single place to update when the policy changes; keep in sync
        # with ``hitl.DEFAULT_AUTO_APPROVED_TOOLS``.
        assert DEFAULT_AUTO_APPROVED_TOOLS == {
            "create_gmail_draft",
            "update_gmail_draft",
            "create_notion_page",
            "create_confluence_page",
            "create_google_drive_file",
            "create_dropbox_file",
            "create_onedrive_file",
        }

    def test_set_is_immutable(self) -> None:
        # A frozenset cannot be widened at runtime by accident, which
        # would silently grow the auto-approval surface.
        assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset)

    def test_send_tools_are_not_auto_approved(self) -> None:
        # External-broadcast tools must always prompt.
        always_gated = (
            "send_gmail_email",
            "send_discord_message",
            "send_teams_message",
            "delete_notion_page",
            "create_calendar_event",
            "delete_calendar_event",
        )
        for tool_name in always_gated:
            assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, (
                f"{tool_name} must remain HITL-gated"
            )
class TestRequestApprovalAutoBypass:
    def test_auto_approved_tool_skips_interrupt(self) -> None:
        # No interrupt mock set up — if the function attempted to call
        # ``langgraph.types.interrupt`` it would raise GraphInterrupt, so
        # receiving a clean HITLResult proves the bypass fired.
        outcome = request_approval(
            action_type="gmail_draft_creation",
            tool_name="create_gmail_draft",
            params={"to": "alice@example.com", "subject": "hi", "body": "hey"},
        )
        assert isinstance(outcome, HITLResult)
        assert outcome.rejected is False
        assert outcome.decision_type == "auto_approved"
        # Params come back untouched (no user edits possible on bypass).
        assert outcome.params == {
            "to": "alice@example.com",
            "subject": "hi",
            "body": "hey",
        }

    def test_non_listed_tool_still_attempts_interrupt(self) -> None:
        # A tool NOT in the default list must reach ``langgraph.interrupt``.
        # Outside a runnable context that call raises RuntimeError — which
        # is exactly the signal we want: the bypass did NOT fire.
        with pytest.raises(RuntimeError, match="runnable context"):
            request_approval(
                action_type="gmail_email_send",
                tool_name="send_gmail_email",
                params={"to": "alice@example.com", "subject": "hi", "body": "hey"},
            )

    def test_user_trusted_tools_still_take_precedence(self) -> None:
        # ``trusted_tools`` (per-connector "always allow" from MCP/UI) is
        # checked BEFORE the default list and must keep working for tools
        # outside that list.
        outcome = request_approval(
            action_type="mcp_tool_call",
            tool_name="my_custom_mcp_tool",
            params={"x": 1},
            trusted_tools=["my_custom_mcp_tool"],
        )
        assert outcome.decision_type == "trusted"
        assert outcome.rejected is False

    def test_auto_approved_overrides_no_trusted_tools(self) -> None:
        # Empty trusted_tools + tool in the default list must still bypass;
        # this pins the check order inside request_approval.
        outcome = request_approval(
            action_type="notion_page_creation",
            tool_name="create_notion_page",
            params={"title": "Plan"},
            trusted_tools=[],
        )
        assert outcome.decision_type == "auto_approved"

View file

@ -0,0 +1,333 @@
"""Cloud-mode behavior tests for the new ``rm`` and ``rmdir`` filesystem tools.
The tools build ``Command(update=...)`` payloads that the persistence
middleware applies at end of turn. These tests stub out the backend and
runtime to assert the staging payload shape:
* ``rm`` queues into ``pending_deletes`` and tombstones state files.
* ``rm`` rejects directories, ``/documents``, root, and the anonymous doc.
* ``rmdir`` queues into ``pending_dir_deletes`` and rejects non-empty dirs.
* ``rmdir`` un-stages a same-turn ``mkdir`` rather than queuing a delete.
* ``rmdir`` refuses to drop the cwd or any of its ancestors.
* ``KBPostgresBackend`` view-helpers honor staged deletes.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock
import pytest
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
pytestmark = pytest.mark.unit
def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD):
    """Build a ``SurfSenseFilesystemMiddleware`` without running ``__init__``.

    ``__new__`` bypasses constructor dependencies; only the two attributes
    the rm/rmdir tools read are injected.
    """
    mw = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
    mw._filesystem_mode = mode
    mw._custom_tool_descriptions = {}
    return mw
def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc"):
state = state or {}
state.setdefault("cwd", "/documents")
return SimpleNamespace(state=state, tool_call_id=tool_call_id)
class _KBBackendStub(KBPostgresBackend):
    """Construct-able subclass of :class:`KBPostgresBackend` for tests.

    The real ``__init__`` (which expects a runtime + DB session) is
    bypassed; only the methods the rm/rmdir tools touch are injected.
    Inheriting keeps ``isinstance(backend, KBPostgresBackend)`` checks in
    the tools happy — that check gates them off the desktop code path.
    """

    def __init__(self, *, children=None, file_data=None) -> None:
        self.als_info = AsyncMock(return_value=children or [])
        # ``None`` means "no such file"; otherwise (data, doc_id) pair.
        if file_data is None:
            loaded = None
        else:
            loaded = (file_data, 17)
        self._load_file_data = AsyncMock(return_value=loaded)
def _make_backend_stub(*, children=None, file_data=None) -> KBPostgresBackend:
    """Shorthand factory around :class:`_KBBackendStub`."""
    return _KBBackendStub(children=children, file_data=file_data)
def _bind_backend(middleware, backend):
"""Inject a backend resolver onto the middleware test instance."""
middleware._get_backend = lambda runtime: backend
return backend
# ---------------------------------------------------------------------------
# rm
# ---------------------------------------------------------------------------
class TestRmStaging:
    """Cloud-mode ``rm`` stages deletes via Command payloads (no DB writes)."""

    @pytest.mark.asyncio
    async def test_stages_delete_and_tombstones_state(self):
        m = _make_middleware()
        _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]}))
        runtime = _runtime(
            {
                "cwd": "/documents",
                "files": {"/documents/notes.md": {"content": ["hello"]}},
                "doc_id_by_path": {"/documents/notes.md": 17},
            },
            tool_call_id="tc-1",
        )
        tool = m._create_rm_tool()
        result = await tool.coroutine("/documents/notes.md", runtime=runtime)
        # A Command payload (has ``.update``) means success; plain strings
        # are error messages.
        assert hasattr(result, "update"), f"expected Command, got {result!r}"
        update = result.update
        assert update["pending_deletes"] == [
            {"path": "/documents/notes.md", "tool_call_id": "tc-1"}
        ]
        # ``None`` values act as tombstones for the state merge reducers.
        assert update["files"] == {"/documents/notes.md": None}
        assert update["doc_id_by_path"] == {"/documents/notes.md": None}

    @pytest.mark.asyncio
    async def test_rejects_documents_root(self):
        m = _make_middleware()
        runtime = _runtime()
        tool = m._create_rm_tool()
        result = await tool.coroutine("/documents", runtime=runtime)
        # String result == refusal; the namespace root must never be rm'd.
        assert isinstance(result, str)
        assert "refusing to rm" in result

    @pytest.mark.asyncio
    async def test_rejects_root(self):
        m = _make_middleware()
        runtime = _runtime()
        tool = m._create_rm_tool()
        result = await tool.coroutine("/", runtime=runtime)
        assert isinstance(result, str)
        assert "refusing to rm" in result

    @pytest.mark.asyncio
    async def test_rejects_directory_via_staged_dirs(self):
        # A directory created earlier in the same turn lives only in
        # ``staged_dirs`` — rm must still recognize it as a directory.
        m = _make_middleware()
        runtime = _runtime(
            {
                "staged_dirs": ["/documents/team-x"],
            }
        )
        tool = m._create_rm_tool()
        result = await tool.coroutine("/documents/team-x", runtime=runtime)
        assert isinstance(result, str)
        assert "directory" in result.lower()
        # The error should steer the agent toward the right tool.
        assert "rmdir" in result

    @pytest.mark.asyncio
    async def test_rejects_directory_via_listing(self):
        # DB-backed directory detection: a backend listing with children
        # under the path marks it as a directory.
        m = _make_middleware()
        _bind_backend(
            m,
            _make_backend_stub(
                children=[{"path": "/documents/foo/x.md", "is_dir": False}]
            ),
        )
        runtime = _runtime()
        tool = m._create_rm_tool()
        result = await tool.coroutine("/documents/foo", runtime=runtime)
        assert isinstance(result, str)
        assert "directory" in result.lower()

    @pytest.mark.asyncio
    async def test_rejects_anonymous_doc(self):
        # The anonymous upload doc is read-only; rm must refuse it.
        m = _make_middleware()
        runtime = _runtime(
            {
                "kb_anon_doc": {
                    "path": "/documents/uploaded.xml",
                    "title": "uploaded",
                    "content": "",
                    "chunks": [],
                }
            }
        )
        tool = m._create_rm_tool()
        result = await tool.coroutine("/documents/uploaded.xml", runtime=runtime)
        assert isinstance(result, str)
        assert "read-only" in result

    @pytest.mark.asyncio
    async def test_drops_path_from_dirty_paths(self):
        # Deleting a path that was edited earlier in the turn must remove
        # it from the dirty set so no stale snapshot binding survives.
        m = _make_middleware()
        _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]}))
        runtime = _runtime(
            {
                "files": {"/documents/notes.md": {"content": ["x"]}},
                "doc_id_by_path": {"/documents/notes.md": 17},
                "dirty_paths": ["/documents/notes.md"],
            }
        )
        tool = m._create_rm_tool()
        result = await tool.coroutine("/documents/notes.md", runtime=runtime)
        update = result.update
        # First element is _CLEAR sentinel; the rest must NOT contain the
        # rm'd path.
        dirty = update.get("dirty_paths") or []
        assert "/documents/notes.md" not in dirty[1:]
# ---------------------------------------------------------------------------
# rmdir
# ---------------------------------------------------------------------------
class TestRmdirStaging:
    """Cloud-mode ``rmdir`` staging: queueing, un-staging, and refusals."""

    @pytest.mark.asyncio
    async def test_stages_dir_delete_when_empty_and_db_backed(self):
        m = _make_middleware()
        backend = _bind_backend(m, _make_backend_stub(children=[]))
        # Override _load_file_data to return None (folder, not a file) and
        # parent listing to claim the folder exists.
        backend._load_file_data = AsyncMock(return_value=None)
        backend.als_info = AsyncMock(
            side_effect=[
                [],  # children of /documents/proj
                [
                    {"path": "/documents/proj", "is_dir": True},
                ],  # parent listing
            ]
        )
        runtime = _runtime(
            {
                "cwd": "/documents",
            },
            tool_call_id="tc-rd",
        )
        tool = m._create_rmdir_tool()
        result = await tool.coroutine("/documents/proj", runtime=runtime)
        assert hasattr(result, "update")
        update = result.update
        # The delete is queued with the originating tool_call_id so the
        # persistence middleware can bind it to an action-log row.
        assert update["pending_dir_deletes"] == [
            {"path": "/documents/proj", "tool_call_id": "tc-rd"}
        ]

    @pytest.mark.asyncio
    async def test_rejects_non_empty(self):
        # POSIX rmdir semantics: a directory with children is refused.
        m = _make_middleware()
        _bind_backend(
            m,
            _make_backend_stub(
                children=[{"path": "/documents/proj/x.md", "is_dir": False}]
            ),
        )
        runtime = _runtime()
        tool = m._create_rmdir_tool()
        result = await tool.coroutine("/documents/proj", runtime=runtime)
        assert isinstance(result, str)
        assert "not empty" in result

    @pytest.mark.asyncio
    async def test_unstages_same_turn_mkdir(self):
        # mkdir + rmdir within one turn should cancel out: no pending
        # delete is queued, the staged dir is simply removed.
        m = _make_middleware()
        _bind_backend(m, _make_backend_stub(children=[]))
        runtime = _runtime(
            {
                "cwd": "/documents",
                "staged_dirs": ["/documents/scratch"],
            },
            tool_call_id="tc-rd",
        )
        tool = m._create_rmdir_tool()
        result = await tool.coroutine("/documents/scratch", runtime=runtime)
        assert hasattr(result, "update")
        update = result.update
        assert "pending_dir_deletes" not in update
        # _CLEAR sentinel + remaining items (in this case, none).
        staged_after = update["staged_dirs"]
        assert staged_after[0] == "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00"
        assert "/documents/scratch" not in staged_after[1:]

    @pytest.mark.asyncio
    async def test_rejects_root(self):
        # Both the filesystem root and the namespace root are protected.
        m = _make_middleware()
        runtime = _runtime()
        tool = m._create_rmdir_tool()
        for victim in ("/", "/documents"):
            result = await tool.coroutine(victim, runtime=runtime)
            assert isinstance(result, str)
            assert "refusing to rmdir" in result

    @pytest.mark.asyncio
    async def test_rejects_cwd(self):
        m = _make_middleware()
        runtime = _runtime({"cwd": "/documents/proj"})
        tool = m._create_rmdir_tool()
        result = await tool.coroutine("/documents/proj", runtime=runtime)
        assert isinstance(result, str)
        assert "cwd" in result.lower()

    @pytest.mark.asyncio
    async def test_rejects_ancestor_of_cwd(self):
        # Deleting an ancestor would orphan the current working directory.
        m = _make_middleware()
        runtime = _runtime({"cwd": "/documents/proj/sub"})
        tool = m._create_rmdir_tool()
        result = await tool.coroutine("/documents/proj", runtime=runtime)
        assert isinstance(result, str)
        assert "cwd" in result.lower()

    @pytest.mark.asyncio
    async def test_rejects_files(self):
        # A path that loads as file data must be routed to ``rm`` instead.
        m = _make_middleware()
        _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]}))
        runtime = _runtime()
        tool = m._create_rmdir_tool()
        result = await tool.coroutine("/documents/notes.md", runtime=runtime)
        assert isinstance(result, str)
        assert "is a file" in result
# ---------------------------------------------------------------------------
# KBPostgresBackend view filter
# ---------------------------------------------------------------------------
class TestKBPostgresBackendDeleteFilter:
    """als_info / glob / grep should suppress paths queued for delete."""

    def _make_backend(self, state: dict[str, Any]) -> KBPostgresBackend:
        # Real backend, fake runtime: the view helpers under test only
        # consult ``runtime.state``, never a DB session.
        runtime = SimpleNamespace(state=state)
        backend = KBPostgresBackend(search_space_id=1, runtime=runtime)
        return backend

    def test_pending_filesystem_view_returns_deleted_paths(self):
        backend = self._make_backend(
            {
                "pending_deletes": [
                    {"path": "/documents/x.md", "tool_call_id": "t1"},
                ],
                "pending_dir_deletes": [
                    {"path": "/documents/d1", "tool_call_id": "t2"},
                ],
            }
        )
        # View returns (removed file paths, path aliases, deleted dirs).
        removed, alias, deleted_dirs = backend._pending_filesystem_view({})
        assert "/documents/x.md" in removed
        assert "/documents/d1" in deleted_dirs
        assert alias == {}

    def test_dir_suppressed_covers_descendants(self):
        backend = self._make_backend({})
        deleted_dirs = {"/documents/d"}
        # The dir itself and everything beneath it is suppressed...
        assert backend._is_dir_suppressed("/documents/d", deleted_dirs)
        assert backend._is_dir_suppressed("/documents/d/x.md", deleted_dirs)
        assert backend._is_dir_suppressed("/documents/d/sub/y.md", deleted_dirs)
        # ...but siblings outside the deleted tree are not.
        assert not backend._is_dir_suppressed("/documents/other.md", deleted_dirs)

View file

@ -98,10 +98,54 @@ class TestInitialFilesystemState:
state = _initial_filesystem_state()
assert state["cwd"] == "/documents"
assert state["staged_dirs"] == []
assert state["staged_dir_tool_calls"] == {}
assert state["pending_moves"] == []
assert state["pending_deletes"] == []
assert state["pending_dir_deletes"] == []
assert state["doc_id_by_path"] == {}
assert state["dirty_paths"] == []
assert state["dirty_path_tool_calls"] == {}
assert state["kb_priority"] == []
assert state["kb_matched_chunk_ids"] == {}
assert state["kb_anon_doc"] is None
assert state["tree_version"] == 0
class TestMultiEditSamePathCoalescing:
    """Repeated same-path edits in one turn coalesce into a single binding.

    The persistence body uses ``dirty_path_tool_calls[path]`` to find the
    tool_call_id that produced the current state on disk. ``dirty_paths``
    dedupes via :func:`_add_unique_reducer`, so a second edit appends no
    new path entry; ``_dict_merge_with_tombstones_reducer`` lets the
    right-hand side overwrite, so the LATEST tool_call_id wins. That is
    the correct behavior for snapshotting: revert restores the
    pre-mutation state, and back-to-back edits in one turn coalesce into
    a single revertible op (ONE Revert button per turn-per-path, not N).
    """

    def test_dirty_paths_dedupes_repeated_writes(self):
        # Two writes to the same path must yield exactly one dirty entry.
        once = _add_unique_reducer([], ["/documents/a.md"])
        twice = _add_unique_reducer(once, ["/documents/a.md"])
        assert twice == ["/documents/a.md"]

    def test_dirty_path_tool_calls_keeps_latest_tool_call_id(self):
        # Successive writes re-tag the path; the newest id overwrites.
        bindings = _dict_merge_with_tombstones_reducer(
            {}, {"/documents/a.md": "tcid-1"}
        )
        bindings = _dict_merge_with_tombstones_reducer(
            bindings, {"/documents/a.md": "tcid-2"}
        )
        assert bindings == {"/documents/a.md": "tcid-2"}

    def test_rm_tombstones_dirty_path_tool_call(self):
        # ``rm`` writes ``{path: None}`` so no stale binding leaks past
        # the delete.
        bindings = _dict_merge_with_tombstones_reducer(
            {"/documents/a.md": "tcid-1"}, {"/documents/a.md": None}
        )
        assert bindings == {}

View file

@ -0,0 +1,83 @@
"""Smoke test for the ``134_relax_revision_fks`` Alembic migration.
A full apply/rollback test would require a live Postgres; here we verify
the migration module's static contract:
* The chain wires it as a successor of ``133_drop_documents_content_hash_unique``.
* ``upgrade()`` declares two FK creations with ``ondelete='SET NULL'``
(one for ``document_revisions.document_id``, one for
``folder_revisions.folder_id``).
* ``downgrade()`` re-establishes ``ondelete='CASCADE'`` after draining
orphaned revisions.
If any of these invariants regress the snapshot/revert pipeline silently
loses the ability to undo ``rm`` / ``rmdir`` on environments that ran the
migration "down" or never ran it at all.
"""
from __future__ import annotations
import importlib.util
import inspect
from pathlib import Path
import pytest
pytestmark = pytest.mark.unit
_MIGRATION_PATH = (
Path(__file__).resolve().parents[3]
/ "alembic"
/ "versions"
/ "134_relax_revision_fks.py"
)
def _load_migration():
    """Import the migration module directly from its file path.

    No package import machinery is needed — the Alembic versions folder
    is not an importable package.
    """
    spec = importlib.util.spec_from_file_location("_migration_134", _MIGRATION_PATH)
    assert spec is not None and spec.loader is not None, "could not load migration spec"
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
def test_migration_chain_revision_ids() -> None:
    """The chain must wire 134 as the direct successor of 133."""
    module = _load_migration()
    # Short numeric revision IDs are the in-tree convention (cf. ``133``
    # -> ``134``); the ``134_<slug>.py`` filename is documentation, not
    # the canonical revision string.
    assert getattr(module, "revision", None) == "134"
    assert getattr(module, "down_revision", None) == "133"
def test_migration_exposes_upgrade_and_downgrade() -> None:
    """Both migration entry points must exist and be callable."""
    module = _load_migration()
    for required in ("upgrade", "downgrade"):
        entry = getattr(module, required, None)
        assert callable(entry), f"{required}() is required"
def test_upgrade_creates_set_null_fks_for_both_revision_tables() -> None:
    """upgrade() must relax both revision-table FKs to SET NULL."""
    src = inspect.getsource(_load_migration().upgrade)
    for table in ("document_revisions", "folder_revisions"):
        assert table in src
    # Both new FKs MUST be ON DELETE SET NULL — that's the entire point
    # of the migration: snapshots must outlive their parent row.
    assert src.count('ondelete="SET NULL"') >= 2
    # And the ``document_id`` / ``folder_id`` columns become nullable.
    assert "nullable=True" in src
def test_downgrade_drains_orphans_then_restores_cascade() -> None:
    """downgrade() must delete orphans before re-imposing NOT NULL/CASCADE."""
    src = inspect.getsource(_load_migration().downgrade)
    # Orphaned rows must be drained BEFORE NOT NULL can be re-imposed.
    for drain_stmt in (
        "DELETE FROM document_revisions WHERE document_id IS NULL",
        "DELETE FROM folder_revisions WHERE folder_id IS NULL",
    ):
        assert drain_stmt in src
    # Then the original CASCADE/NOT NULL contract is restored.
    assert src.count('ondelete="CASCADE"') >= 2
    assert "nullable=False" in src

View file

@ -168,6 +168,8 @@ class TestModeSpecificPrompts:
"edit_file",
"move_file",
"mkdir",
"rm",
"rmdir",
"list_tree",
"grep",
):
@ -182,6 +184,8 @@ class TestModeSpecificPrompts:
"edit_file",
"move_file",
"mkdir",
"rm",
"rmdir",
"list_tree",
"grep",
):
@ -190,6 +194,18 @@ class TestModeSpecificPrompts:
assert "/documents/" not in text, f"{name} mentions cloud namespace"
assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"
def test_cloud_descs_include_rm_and_rmdir(self):
    """Cloud tool descriptions must cover both new delete tools."""
    descs = _build_tool_descriptions(FilesystemMode.CLOUD)
    assert "rm" in descs
    assert "rmdir" in descs
    assert "Deletes a single file" in descs["rm"]
    assert "Deletes an empty directory" in descs["rmdir"]
    # rmdir's description should anchor on the POSIX semantics it copies.
    assert "rmdir" in descs["rmdir"]
    assert "POSIX" in descs["rmdir"]
def test_desktop_descs_warn_about_irreversibility(self):
    """Desktop-mode delete descriptions must carry the no-undo warning."""
    descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER)
    for tool_name in ("rm", "rmdir"):
        assert "NOT reversible" in descs[tool_name]
def test_sandbox_addendum_appended_when_available(self):
prompt = _build_filesystem_system_prompt(
FilesystemMode.CLOUD, sandbox_available=True

View file

@ -0,0 +1,309 @@
"""Unit tests for the kb_persistence snapshot helpers.
The full ``commit_staged_filesystem_state`` body exercises a real session
in integration tests; here we verify the building blocks used by the
snapshot/revert pipeline:
* ``_find_action_ids_batch`` issues a SINGLE query for N tool_call_ids
(regression guard against the N+1 lookup pattern).
* ``_mark_action_reversible`` is a no-op when ``action_id`` is ``None``.
* ``_doc_revision_payload`` and ``_load_chunks_for_snapshot`` produce the
shape the snapshot helpers consume.
These tests use ``MagicMock`` / ``AsyncMock`` against a fake session so
the assertions run in milliseconds and don't require Postgres.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.agents.new_chat.middleware import kb_persistence
pytestmark = pytest.mark.unit
class _FakeResult:
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
self._rows = rows or []
self._scalar = scalar
def all(self) -> list[Any]:
return list(self._rows)
def scalar_one_or_none(self) -> Any:
return self._scalar
class _FakeSession:
def __init__(self) -> None:
self.execute = AsyncMock()
@pytest.mark.asyncio
async def test_find_action_ids_batch_issues_single_query() -> None:
    """The lookup MUST be a single ``IN (...)`` SELECT, not N selects."""
    session = _FakeSession()
    session.execute.return_value = _FakeResult(
        rows=[
            MagicMock(id=11, tool_call_id="tc-a"),
            MagicMock(id=22, tool_call_id="tc-b"),
            MagicMock(id=33, tool_call_id="tc-c"),
        ]
    )

    resolved = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=1,
        tool_call_ids={"tc-a", "tc-b", "tc-c"},
    )

    assert resolved == {"tc-a": 11, "tc-b": 22, "tc-c": 33}
    assert session.execute.await_count == 1, (
        "Snapshot binding must batch into ONE query; got "
        f"{session.execute.await_count} (regression: N+1 lookup pattern)."
    )
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_thread_id_missing() -> None:
    """A missing thread_id must produce an empty map with zero queries."""
    session = _FakeSession()
    resolved = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=None,
        tool_call_ids={"tc-a"},
    )
    assert resolved == {}
    assert session.execute.await_count == 0
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_no_calls() -> None:
    """An empty tool_call_ids set must skip the query entirely."""
    session = _FakeSession()
    resolved = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=42,
        tool_call_ids=set(),
    )
    assert resolved == {}
    assert session.execute.await_count == 0
@pytest.mark.asyncio
async def test_mark_action_reversible_is_noop_for_null_id() -> None:
    """``action_id=None`` means "row never bound" — no UPDATE is issued."""
    session = _FakeSession()
    await kb_persistence._mark_action_reversible(session, action_id=None)  # type: ignore[arg-type]
    assert session.execute.await_count == 0
@pytest.mark.asyncio
async def test_mark_action_reversible_runs_update_for_real_id() -> None:
    """A concrete action_id triggers exactly one UPDATE statement."""
    session = _FakeSession()
    await kb_persistence._mark_action_reversible(session, action_id=99)  # type: ignore[arg-type]
    assert session.execute.await_count == 1
def test_doc_revision_payload_captures_metadata_virtual_path() -> None:
    """Snapshot helpers must capture ``metadata_before`` for revert reuse."""
    doc = MagicMock()
    doc.content = "body"
    doc.title = "notes.md"
    doc.folder_id = 7
    doc.document_metadata = {"virtual_path": "/documents/team/notes.md"}

    payload = kb_persistence._doc_revision_payload(
        doc, chunks_before=[{"content": "x"}]
    )

    # Every "before" field must round-trip from the document object.
    expected = {
        "title_before": "notes.md",
        "folder_id_before": 7,
        "content_before": "body",
        "chunks_before": [{"content": "x"}],
        "metadata_before": {"virtual_path": "/documents/team/notes.md"},
    }
    for key, value in expected.items():
        assert payload[key] == value
def test_doc_revision_payload_handles_missing_metadata() -> None:
    """A document with no ``document_metadata`` snapshots ``None`` cleanly."""
    doc = MagicMock()
    doc.content = ""
    doc.title = ""
    doc.folder_id = None
    doc.document_metadata = None

    payload = kb_persistence._doc_revision_payload(doc)

    assert payload["metadata_before"] is None
@pytest.mark.asyncio
async def test_load_chunks_for_snapshot_returns_content_only() -> None:
    """Snapshot chunks intentionally omit embeddings (regenerated on revert)."""
    fake_session = _FakeSession()
    fake_session.execute.return_value = _FakeResult(
        rows=[MagicMock(content="alpha"), MagicMock(content="beta")]
    )

    chunks = await kb_persistence._load_chunks_for_snapshot(
        fake_session,
        doc_id=42,  # type: ignore[arg-type]
    )

    assert chunks == [{"content": "alpha"}, {"content": "beta"}]
# ---------------------------------------------------------------------------
# Deferred reversibility-flip dispatches.
#
# The snapshot helpers used to dispatch ``action_log_updated`` directly
# from inside the SAVEPOINT block. That meant the SSE side-channel
# could tell the UI a row was reversible while the OUTER transaction
# was still pending — and if the outer commit failed, every SAVEPOINT
# rolled back too, leaving the UI in a state inconsistent with
# durable storage. The deferred-dispatch contract fixes that:
#
# • when a ``deferred_dispatches`` list is provided, the helper
# APPENDS the action_id and does NOT dispatch;
# • the caller (``commit_staged_filesystem_state``) flushes the list
# only AFTER ``await session.commit()`` succeeds; on rollback it
# clears the list so nothing is emitted.
# ---------------------------------------------------------------------------
class _NestedCtx:
    """Async context manager mimicking ``session.begin_nested()``.

    ``__aexit__`` returns ``False`` so exceptions raised inside the
    SAVEPOINT block propagate to the caller instead of being swallowed.
    """

    async def __aenter__(self) -> _NestedCtx:
        return self

    async def __aexit__(self, exc_type, exc, tb) -> bool:
        # Never suppress: let the helpers see failures from the block.
        return False
@pytest.mark.asyncio
async def test_pre_write_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Helpers MUST queue dispatches when ``deferred_dispatches`` is set."""
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
    session.flush = AsyncMock()

    # ``session.add`` assigns the primary key the way a real flush would.
    def _add(rev: Any) -> None:
        rev.id = 17

    session.add = MagicMock(side_effect=_add)

    dispatched: list[int] = []

    async def _fake_dispatch(action_id: int | None) -> None:
        if action_id is not None:
            dispatched.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
    )

    deferred: list[int] = []
    doc = MagicMock(id=99, document_metadata={"virtual_path": "/documents/x.md"})
    doc.title = "x.md"
    doc.folder_id = None
    doc.content = "body"

    rev_id = await kb_persistence._snapshot_document_pre_write(
        session,  # type: ignore[arg-type]
        doc=doc,
        action_id=42,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=deferred,
    )

    assert rev_id == 17
    # Inline dispatch must NOT have fired; the action_id is queued.
    assert dispatched == []
    assert deferred == [42]
@pytest.mark.asyncio
async def test_pre_write_snapshot_dispatches_inline_when_list_omitted(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Direct callers (no outer transaction) keep the legacy inline dispatch."""
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
    session.flush = AsyncMock()

    # ``session.add`` assigns the primary key the way a real flush would.
    def _add(rev: Any) -> None:
        rev.id = 7

    session.add = MagicMock(side_effect=_add)

    dispatched: list[int] = []

    async def _fake_dispatch(action_id: int | None) -> None:
        if action_id is not None:
            dispatched.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
    )

    doc = MagicMock(id=11, document_metadata={"virtual_path": "/documents/y.md"})
    doc.title = "y.md"
    doc.folder_id = None
    doc.content = "body"

    await kb_persistence._snapshot_document_pre_write(
        session,  # type: ignore[arg-type]
        doc=doc,
        action_id=88,
        search_space_id=1,
        turn_id="t-1",
        # No deferred_dispatches arg — fall back to inline dispatch.
    )

    assert dispatched == [88]
@pytest.mark.asyncio
async def test_pre_mkdir_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Folder mkdir snapshots honour the same deferred-dispatch contract."""
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock()  # _mark_action_reversible calls execute
    session.flush = AsyncMock()

    def _add(rev: Any) -> None:
        rev.id = 3

    session.add = MagicMock(side_effect=_add)

    dispatched: list[int] = []

    async def _fake_dispatch(action_id: int | None) -> None:
        if action_id is not None:
            dispatched.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
    )

    deferred: list[int] = []
    # NOTE(review): ``name=`` is consumed by the Mock constructor itself (it
    # names the mock for repr purposes), so ``folder.name`` is NOT the string
    # "f" here — confirm the helper under test never reads ``.name``.
    folder = MagicMock(id=2, name="f", parent_id=None, position="a0")

    await kb_persistence._snapshot_folder_pre_mkdir(
        session,  # type: ignore[arg-type]
        folder=folder,
        action_id=55,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=deferred,
    )

    assert dispatched == []
    assert deferred == [55]

View file

@ -0,0 +1,139 @@
"""Unit tests for ``KnowledgeTreeMiddleware`` rendering.
The empty-folder marker is critical UX: without it, the LLM cannot
distinguish a leaf folder containing one document from a leaf folder
that has no descendants at all, and ends up firing ``rmdir`` on
non-empty folders. These tests pin the rendering contract so that
contract cannot silently regress.
"""
from __future__ import annotations
from app.agents.new_chat.middleware.knowledge_tree import KnowledgeTreeMiddleware
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT
def _compute(folder_paths: list[str], doc_paths: list[str]) -> set[str]:
    """Shorthand around the static helper under test."""
    return KnowledgeTreeMiddleware._compute_non_empty_folders(
        folder_paths, doc_paths
    )
class TestComputeNonEmptyFolders:
    """Pin the ancestry math that decides which folders get ``(empty)``."""

    def test_folder_with_direct_document_is_non_empty(self):
        folder_paths = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"]
        doc_paths = [
            f"{DOCUMENTS_ROOT}/Travel/Boarding Pass/southwest.pdf.xml",
        ]
        non_empty = _compute(folder_paths, doc_paths)
        assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" in non_empty

    def test_truly_empty_leaf_folder_is_not_non_empty(self):
        folder_paths = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"]
        doc_paths: list[str] = []
        assert _compute(folder_paths, doc_paths) == set()

    def test_documents_propagate_up_to_all_ancestors(self):
        folder_paths = [
            f"{DOCUMENTS_ROOT}/A",
            f"{DOCUMENTS_ROOT}/A/B",
            f"{DOCUMENTS_ROOT}/A/B/C",
        ]
        doc_paths = [f"{DOCUMENTS_ROOT}/A/B/C/file.xml"]
        non_empty = _compute(folder_paths, doc_paths)
        assert non_empty == {
            f"{DOCUMENTS_ROOT}/A",
            f"{DOCUMENTS_ROOT}/A/B",
            f"{DOCUMENTS_ROOT}/A/B/C",
        }

    def test_chain_with_subfolders_marks_only_leaf_empty(self):
        # POSIX-like semantic: a folder is "empty" only if it has no
        # immediate children (docs OR sub-folders). The model needs this
        # because parallel ``rmdir`` calls all see the same starting state,
        # so trying to rmdir a parent before its children is never safe.
        folder_paths = [
            f"{DOCUMENTS_ROOT}/X",
            f"{DOCUMENTS_ROOT}/X/Y",
            f"{DOCUMENTS_ROOT}/X/Y/Z",
        ]
        non_empty = _compute(folder_paths, [])
        # Only ``X/Y/Z`` (the leaf) is empty. ``X`` and ``X/Y`` each have a
        # sub-folder child, so they are non-empty and should NOT carry the
        # ``(empty)`` marker.
        assert non_empty == {f"{DOCUMENTS_ROOT}/X", f"{DOCUMENTS_ROOT}/X/Y"}

    def test_sibling_with_doc_does_not_mark_other_sibling_non_empty(self):
        # Mirrors a real DB layout where every intermediate folder is
        # materialized in the ``folders`` table.
        folder_paths = [
            f"{DOCUMENTS_ROOT}/Travel",
            f"{DOCUMENTS_ROOT}/Travel/Boarding Pass",
            f"{DOCUMENTS_ROOT}/Travel/Notes",
        ]
        doc_paths = [f"{DOCUMENTS_ROOT}/Travel/Notes/itinerary.xml"]
        non_empty = _compute(folder_paths, doc_paths)
        # ``Travel`` is non-empty because it has children, ``Notes`` is non-empty
        # because of the doc, but ``Boarding Pass`` (sibling leaf) is empty.
        assert f"{DOCUMENTS_ROOT}/Travel" in non_empty
        assert f"{DOCUMENTS_ROOT}/Travel/Notes" in non_empty
        assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" not in non_empty
class TestFormatTreeRendering:
    """Integration check: empty leaf gets ``(empty)`` marker; non-empty doesn't."""

    def _render(
        self,
        folder_paths: list[str],
        doc_specs: list[dict],
    ) -> str:
        """Run ``_format_tree`` over a synthetic path index and doc rows."""
        from app.agents.new_chat.path_resolver import PathIndex

        index = PathIndex(
            folder_paths={i + 1: p for i, p in enumerate(folder_paths)},
        )

        class _Row:
            # Ad-hoc attribute bag standing in for a document ORM row.
            def __init__(self, **kw):
                self.__dict__.update(kw)

        docs = [_Row(**spec) for spec in doc_specs]
        mw = KnowledgeTreeMiddleware(
            search_space_id=1,
            filesystem_mode=None,  # type: ignore[arg-type]
        )
        return mw._format_tree(index, docs)

    def test_renders_empty_marker_only_for_truly_empty_folders(self):
        # Reproduces the failure scenario from the bug report:
        # ``Boarding Pass`` is empty (its only doc was just deleted), while
        # ``Tax Returns`` still has ``federal.pdf``. All intermediate
        # folders are present in the index, mirroring the real DB layout.
        folder_paths = [
            "/documents/File Upload",
            "/documents/File Upload/2026-04-08",
            "/documents/File Upload/2026-04-08/Travel",
            "/documents/File Upload/2026-04-08/Travel/Boarding Pass",
            "/documents/File Upload/2026-04-15",
            "/documents/File Upload/2026-04-15/Finance",
            "/documents/File Upload/2026-04-15/Finance/Tax Returns",
        ]
        tax_returns_folder_id = (
            folder_paths.index("/documents/File Upload/2026-04-15/Finance/Tax Returns")
            + 1
        )
        rendered = self._render(
            folder_paths=folder_paths,
            doc_specs=[
                {
                    "id": 100,
                    "title": "federal.pdf",
                    "folder_id": tax_returns_folder_id,
                },
            ],
        )
        assert "Boarding Pass/ (empty)" in rendered
        assert "Tax Returns/ (empty)" not in rendered
        # Intermediate ancestors of the doc must NOT be marked empty.
        assert "Finance/ (empty)" not in rendered
        assert "2026-04-15/ (empty)" not in rendered

View file

@ -69,3 +69,74 @@ def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path):
assert write.error is not None
assert "parent directory" in write.error
assert not (tmp_path / "tempoo").exists()
def test_local_backend_delete_file_success(tmp_path: Path):
    """Deleting an existing file succeeds and removes it from disk."""
    backend = LocalFolderBackend(str(tmp_path))
    target = tmp_path / "delete-me.md"
    target.write_text("bye")

    res = backend.delete_file("/delete-me.md")

    assert res.error is None
    assert res.path == "/delete-me.md"
    assert not target.exists()
def test_local_backend_delete_file_rejects_directory(tmp_path: Path):
    """``delete_file`` must refuse a directory and leave it intact."""
    backend = LocalFolderBackend(str(tmp_path))
    subdir = tmp_path / "subdir"
    subdir.mkdir()

    res = backend.delete_file("/subdir")

    assert res.error is not None
    assert "directory" in res.error
    assert subdir.exists()
def test_local_backend_delete_file_missing_returns_error(tmp_path: Path):
    """Deleting a nonexistent path reports 'not found' instead of raising."""
    backend = LocalFolderBackend(str(tmp_path))
    res = backend.delete_file("/nope.md")
    assert res.error is not None
    assert "not found" in res.error
def test_local_backend_rmdir_success(tmp_path: Path):
    """Removing an empty directory succeeds and deletes it on disk."""
    backend = LocalFolderBackend(str(tmp_path))
    empty_dir = tmp_path / "empty"
    empty_dir.mkdir()

    res = backend.rmdir("/empty")

    assert res.error is None
    assert res.path == "/empty"
    assert not empty_dir.exists()
def test_local_backend_rmdir_rejects_non_empty(tmp_path: Path):
    """``rmdir`` must refuse a directory that still has children."""
    backend = LocalFolderBackend(str(tmp_path))
    parent = tmp_path / "withkid"
    parent.mkdir()
    child = parent / "child.md"
    child.write_text("x")

    res = backend.rmdir("/withkid")

    assert res.error is not None
    assert "not empty" in res.error
    assert child.exists()
def test_local_backend_rmdir_rejects_file(tmp_path: Path):
    """``rmdir`` on a regular file reports 'not a directory'."""
    backend = LocalFolderBackend(str(tmp_path))
    (tmp_path / "f.md").write_text("x")

    res = backend.rmdir("/f.md")

    assert res.error is not None
    assert "not a directory" in res.error
def test_local_backend_rmdir_rejects_root(tmp_path: Path):
    """``rmdir /`` MUST fail. The exact error wording comes from
    ``_resolve_virtual`` (root resolves to outside the sandbox); what
    matters is that the call returns an error and does NOT delete the
    sandbox root on disk."""
    backend = LocalFolderBackend(str(tmp_path))
    res = backend.rmdir("/")
    assert res.error is not None
    # Accept either wording; the contract is "error + root survives".
    assert "Invalid path" in res.error or "root" in res.error
    assert tmp_path.exists()

View file

@ -0,0 +1,143 @@
"""Unit tests for the edit-from-arbitrary-position helpers inside ``new_chat_routes``.
The regenerate route's edit-from-position path introduces:
* ``_find_pre_turn_checkpoint_id`` walks LangGraph checkpoint tuples
newest-first and picks the first one whose ``metadata["turn_id"]``
differs from the edited turn. That checkpoint is the rewind target
(state immediately before the edited turn started).
* ``RegenerateRequest`` accepts ``from_message_id`` + ``revert_actions``
with a validator that prevents callers from requesting a revert pass
without specifying which turn to roll back.
These are pure-Python helpers that don't need a live DB, so we exercise
them with a small ``CheckpointTuple``-shaped namespace and direct
schema instantiation.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from app.routes.new_chat_routes import _find_pre_turn_checkpoint_id
from app.schemas.new_chat import RegenerateRequest
def _cp(checkpoint_id: str, turn_id: str | None) -> SimpleNamespace:
"""Build a fake ``CheckpointTuple`` with the metadata shape we read."""
return SimpleNamespace(
config={"configurable": {"checkpoint_id": checkpoint_id}},
metadata={"turn_id": turn_id} if turn_id is not None else {},
)
class TestFindPreTurnCheckpointId:
    """Pin the rewind-target selection for edit-from-position regeneration."""

    def test_returns_last_pre_turn_checkpoint_when_editing_latest_turn(self) -> None:
        # Newest-first: T2 is the most-recent turn. The latest non-T2
        # checkpoint (cp2) is the rewind target — state immediately
        # before T2 began.
        tuples = [
            _cp("cp4", "T2"),
            _cp("cp3", "T2"),
            _cp("cp2", "T1"),
            _cp("cp1", "T1"),
        ]
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2"

    def test_returns_pre_turn_checkpoint_when_later_turns_exist(self) -> None:
        # Regression for the bug where walking newest-first returned the
        # FIRST cp with ``turn_id != target`` — which is one of the
        # later-turn checkpoints, NOT the pre-turn boundary. Editing
        # T2 must rewind to the latest T1 checkpoint (cp2), not to the
        # latest T3 checkpoint (cp6).
        tuples = [
            _cp("cp6", "T3"),
            _cp("cp5", "T3"),
            _cp("cp4", "T2"),
            _cp("cp3", "T2"),
            _cp("cp2", "T1"),
            _cp("cp1", "T1"),
        ]
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2"

    def test_returns_none_when_editing_first_turn(self) -> None:
        # No pre-turn boundary exists; caller is expected to fall back
        # to the oldest checkpoint or special-case "first turn of the
        # thread".
        tuples = [
            _cp("cp4", "T2"),
            _cp("cp3", "T2"),
            _cp("cp2", "T1"),
            _cp("cp1", "T1"),
        ]
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T1") is None

    def test_returns_none_when_only_edited_turn_present(self) -> None:
        tuples = [_cp("cp2", "T2"), _cp("cp1", "T2")]
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") is None

    def test_returns_none_for_empty_history(self) -> None:
        assert _find_pre_turn_checkpoint_id([], turn_id="T1") is None

    def test_legacy_checkpoints_without_turn_id_count_as_pre_turn(self) -> None:
        # Checkpoints written before migration 136 have no
        # ``metadata.turn_id``. They should be eligible rewind targets —
        # they came before the edited turn began.
        tuples = [
            _cp("cp3", "T2"),
            SimpleNamespace(
                config={"configurable": {"checkpoint_id": "cp2"}},
                metadata=None,
            ),
            _cp("cp1", "T1"),
        ]
        # Walking oldest-first: cp1(T1) tracked, cp2(legacy/None) tracked,
        # then cp3(T2) crosses the boundary -> return cp2.
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2"

    def test_skips_checkpoint_missing_checkpoint_id_in_config(self) -> None:
        # If a checkpoint tuple's ``config["configurable"]`` is missing
        # the ``checkpoint_id`` key (corrupt / partial), we keep the
        # last known good target instead of crashing.
        broken = SimpleNamespace(
            config={"configurable": {}}, metadata={"turn_id": "T1"}
        )
        tuples = [
            _cp("cp3", "T2"),
            broken,
            _cp("cp1", "T1"),
        ]
        # cp1(T1) tracked, broken skipped, cp3(T2) -> return cp1.
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp1"
class TestRegenerateRequestValidation:
    """Schema-level guardrails for the edit-from-position request body."""

    def test_revert_actions_requires_from_message_id(self) -> None:
        # Requesting a revert pass without naming the turn is rejected
        # by the schema validator; the error must name the missing field.
        with pytest.raises(Exception) as exc:
            RegenerateRequest(
                search_space_id=1,
                user_query="hi",
                revert_actions=True,
            )
        msg = str(exc.value).lower()
        assert "from_message_id" in msg

    def test_from_message_id_without_revert_is_allowed(self) -> None:
        req = RegenerateRequest(
            search_space_id=1,
            user_query="hi",
            from_message_id=42,
        )
        assert req.from_message_id == 42
        assert req.revert_actions is False

    def test_revert_actions_with_from_message_id_passes(self) -> None:
        req = RegenerateRequest(
            search_space_id=1,
            user_query="hi",
            from_message_id=42,
            revert_actions=True,
        )
        assert req.revert_actions is True

View file

@ -0,0 +1,530 @@
"""Unit tests for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
The per-turn batch revert route walks rows in reverse ``created_at``
order, reverts each independently, and returns a per-action result
list. Partial success is normal — the response status
is ``"partial"`` whenever any row could not be reverted, but we never
collapse the whole batch into a 4xx.
These tests stub ``load_thread`` / ``revert_action`` and feed a fake
session, so they exercise the route's dispatch logic without a real DB.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.routes import agent_revert_route
from app.services.revert_service import RevertOutcome
@dataclass
class _FakeAction:
    """Row double for an action-log entry as consumed by the revert route."""

    id: int
    tool_name: str
    # Owner of the action; the route compares it against the requester's id
    # (see test_permission_denied_when_other_user_owns_action).
    user_id: str | None = "u1"
    # Non-None marks this row as itself a revert of another action.
    reverse_of: int | None = None
    error: dict | None = None
@dataclass
class _FakeUser:
    """Authenticated-user stand-in; tests vary ``id`` to exercise ownership."""

    id: str = "u1"
@dataclass
class _ScalarResult:
rows: list[Any]
def first(self) -> Any:
return self.rows[0] if self.rows else None
def all(self) -> list[Any]:
return list(self.rows)
@dataclass
class _Result:
rows: list[Any] = field(default_factory=list)
def scalars(self) -> _ScalarResult:
return _ScalarResult(self.rows)
def all(self) -> list[Any]:
# ``_was_already_reverted_batch`` calls ``.all()`` directly on
# the row-tuple result (no ``.scalars()`` indirection). The
# rows queued for that helper are list[(revert_id, original_id)].
return list(self.rows)
class _FakeNestedCtx:
    """Async context manager that mimics ``session.begin_nested()``.

    The route raises a sentinel exception inside this block to roll back
    bad rows. We just pass the exception through.
    """

    async def __aenter__(self) -> _FakeNestedCtx:
        return self

    async def __aexit__(self, exc_type, exc, tb) -> bool:
        # Returning False (or None) propagates the exception; the route
        # catches its own sentinel above this layer.
        return False
class _FakeSession:
"""Minimal AsyncSession stand-in for the revert-turn route.
Holds a queue of result objects; each ``execute(...)`` pops the next
one. The route calls ``execute`` exactly once per query so this maps
cleanly onto the assertion order of the test.
"""
def __init__(self) -> None:
self._results: list[_Result] = []
self.committed = False
self.rolled_back = False
# Count execute() calls to assert "no N+1 reverts".
self.execute_call_count = 0
def queue(self, *results: _Result) -> None:
self._results.extend(results)
async def execute(self, _stmt: Any) -> _Result:
self.execute_call_count += 1
if not self._results:
return _Result(rows=[])
return self._results.pop(0)
def begin_nested(self) -> _FakeNestedCtx:
return _FakeNestedCtx()
async def commit(self) -> None:
self.committed = True
async def rollback(self) -> None:
self.rolled_back = True
def _enabled_flags() -> AgentFeatureFlags:
    """Flag set with the new stack, action log, and revert route all enabled."""
    return AgentFeatureFlags(
        disable_new_agent_stack=False,
        enable_action_log=True,
        enable_revert_route=True,
    )
@pytest.fixture
def patch_get_flags():
    """Return a factory that patches the route module's ``get_flags`` lookup."""

    def _patch(flags: AgentFeatureFlags):
        return patch(
            "app.routes.agent_revert_route.get_flags",
            return_value=flags,
        )

    return _patch
class TestFlagGuard:
    """The route must refuse service (503) before touching any rows."""

    @pytest.mark.asyncio
    async def test_returns_503_when_revert_route_disabled(
        self, patch_get_flags
    ) -> None:
        flags = AgentFeatureFlags(
            disable_new_agent_stack=False,
            enable_action_log=True,
            enable_revert_route=False,
        )
        session = _FakeSession()
        with patch_get_flags(flags), pytest.raises(Exception) as exc:
            await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="42:1700000000000",
                session=session,
                user=_FakeUser(),
            )
        assert getattr(exc.value, "status_code", None) == 503
class TestRevertTurnDispatch:
    """Dispatch logic of the per-turn batch revert route (no real DB)."""

    @pytest.mark.asyncio
    async def test_empty_turn_returns_ok_with_no_rows(self, patch_get_flags) -> None:
        """A turn with no logged actions commits and reports an empty batch."""
        session = _FakeSession()
        session.queue(_Result(rows=[]))  # rows query returns nothing
        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-empty",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.total == 0
        assert response.results == []
        assert session.committed is True

    @pytest.mark.asyncio
    async def test_walks_rows_in_reverse_and_reverts_each(
        self, patch_get_flags
    ) -> None:
        """Rows are processed in the queued (reverse-created) order, one revert each."""
        rows = [
            _FakeAction(id=10, tool_name="rm"),
            _FakeAction(id=9, tool_name="write_file"),
            _FakeAction(id=8, tool_name="mkdir"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched ``_was_already_reverted_batch`` probe replaces
        # the previous N per-row SELECTs.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            return RevertOutcome(
                status="ok",
                message=f"reverted-{action.id}",
                new_action_id=100 + action.id,
            )

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-3",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.total == 3
        assert response.reverted == 3
        assert [r.action_id for r in response.results] == [10, 9, 8]
        assert all(r.status == "reverted" for r in response.results)
        assert response.results[0].new_action_id == 110
        # Only TWO ``execute`` calls regardless of the row count: one
        # for the rows query, one for the batched
        # ``_was_already_reverted_batch`` probe. Regression guard
        # against re-introducing the per-row N+1 lookup.
        assert session.execute_call_count == 2, (
            "revert-turn loop must batch idempotency probes; got "
            f"{session.execute_call_count} execute() calls (expected 2)."
        )

    @pytest.mark.asyncio
    async def test_already_reverted_rows_are_marked_idempotent(
        self, patch_get_flags
    ) -> None:
        """Rows with an existing revert are reported without calling revert_action."""
        rows = [_FakeAction(id=5, tool_name="edit_file")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch probe returns ``[(revert_id, original_id)]``.
        session.queue(_Result(rows=[(42, 5)]))
        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-i",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.already_reverted == 1
        assert response.results[0].status == "already_reverted"
        assert response.results[0].new_action_id == 42
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_revert_action_skips_existing_revert_rows(
        self, patch_get_flags
    ) -> None:
        """Rows that are themselves reverts (``_revert:*``) are skipped outright."""
        rows = [_FakeAction(id=99, tool_name="_revert:edit_file", reverse_of=42)]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-rev",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.results[0].status == "skipped"
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_partial_success_when_some_rows_not_reversible(
        self, patch_get_flags
    ) -> None:
        """One irreversible row downgrades the batch status to ``partial``."""
        rows = [
            _FakeAction(id=2, tool_name="send_email"),
            _FakeAction(id=1, tool_name="edit_file"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched idempotency probe.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.tool_name == "send_email":
                return RevertOutcome(
                    status="not_reversible",
                    message="connector revert not yet implemented",
                )
            return RevertOutcome(status="ok", message="ok", new_action_id=500)

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-mix",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "partial"
        assert response.reverted == 1
        assert response.not_reversible == 1
        statuses = sorted(r.status for r in response.results)
        assert statuses == ["not_reversible", "reverted"]

    @pytest.mark.asyncio
    async def test_unexpected_exception_marks_row_failed_not_batch(
        self, patch_get_flags
    ) -> None:
        """A crash on one row must not abort the remaining rows."""
        rows = [
            _FakeAction(id=20, tool_name="edit_file"),
            _FakeAction(id=21, tool_name="edit_file"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched idempotency probe.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.id == 20:
                raise RuntimeError("disk on fire")
            return RevertOutcome(status="ok", message="ok", new_action_id=999)

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-fail",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "partial"
        assert response.failed == 1
        assert response.reverted == 1
        bad = next(r for r in response.results if r.action_id == 20)
        assert bad.status == "failed"
        assert "disk on fire" in (bad.error or "")
        good = next(r for r in response.results if r.action_id == 21)
        assert good.status == "reverted"

    @pytest.mark.asyncio
    async def test_permission_denied_when_other_user_owns_action(
        self, patch_get_flags
    ) -> None:
        """A requester who does not own the action gets ``permission_denied``."""
        rows = [_FakeAction(id=7, tool_name="edit_file", user_id="someone-else")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch idempotency probe (no prior reverts).
        session.queue(_Result(rows=[]))
        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-perm",
                session=session,
                user=_FakeUser(id="not-owner"),
            )
        assert response.status == "partial"
        assert response.results[0].status == "permission_denied"
        # ``permission_denied`` has its own dedicated counter so the
        # response invariant ``total == sum(counters)`` always holds
        # without overloading ``not_reversible`` (which historically
        # absorbed this case and confused frontend toasts).
        assert response.permission_denied == 1
        assert response.not_reversible == 0
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_counter_invariant_holds_across_mixed_outcomes(
        self, patch_get_flags
    ) -> None:
        """Every row is accounted for in EXACTLY ONE counter.

        Mixes one of every supported outcome (reverted, already_reverted,
        not_reversible, permission_denied, failed, skipped) and asserts
        that the sum of counters equals ``response.total``.
        """
        rows = [
            _FakeAction(id=10, tool_name="edit_file"),  # ok
            _FakeAction(id=9, tool_name="edit_file"),  # already_reverted
            _FakeAction(id=8, tool_name="send_email"),  # not_reversible
            _FakeAction(id=7, tool_name="rm", user_id="other"),  # permission_denied
            _FakeAction(id=6, tool_name="edit_file"),  # failed
            _FakeAction(id=5, tool_name="_revert:edit_file", reverse_of=99),  # skipped
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched probe; only id=9 has a prior revert.
        # Schema: list[(revert_id, original_id)].
        session.queue(_Result(rows=[(42, 9)]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.id == 10:
                return RevertOutcome(status="ok", message="ok", new_action_id=500)
            if action.id == 8:
                return RevertOutcome(
                    status="not_reversible",
                    message="connector revert not yet implemented",
                )
            if action.id == 6:
                raise RuntimeError("boom")
            raise AssertionError(f"unexpected revert call for {action.id}")

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route,
                "revert_action",
                AsyncMock(side_effect=_fake_revert),
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-mixed-all",
                session=session,
                user=_FakeUser(),  # only id=7 has a different user_id
            )
        assert response.total == len(rows) == 6
        bucket_sum = (
            response.reverted
            + response.already_reverted
            + response.not_reversible
            + response.permission_denied
            + response.failed
            + response.skipped
        )
        assert bucket_sum == response.total, (
            "Counter invariant broken: total "
            f"({response.total}) != sum of counters ({bucket_sum}). "
            f"Counters: reverted={response.reverted}, "
            f"already_reverted={response.already_reverted}, "
            f"not_reversible={response.not_reversible}, "
            f"permission_denied={response.permission_denied}, "
            f"failed={response.failed}, skipped={response.skipped}"
        )
        assert response.reverted == 1
        assert response.already_reverted == 1
        assert response.not_reversible == 1
        assert response.permission_denied == 1
        assert response.failed == 1
        assert response.skipped == 1

    @pytest.mark.asyncio
    async def test_integrity_error_translates_to_already_reverted(
        self, patch_get_flags
    ) -> None:
        """The partial unique index on ``reverse_of`` raises
        ``IntegrityError`` when a concurrent revert wins the race against
        the pre-flight ``_was_already_reverted`` SELECT. The route MUST
        recover by re-querying for the winning revert id and returning
        ``status="already_reverted"`` (not ``"failed"``) so racing
        clients see consistent idempotent semantics.
        """
        from sqlalchemy.exc import IntegrityError

        rows = [_FakeAction(id=33, tool_name="edit_file")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch pre-flight probe: nothing yet (we'll race).
        session.queue(_Result(rows=[]))
        # Post-IntegrityError fallback uses the SCALAR
        # ``_was_already_reverted`` (single-id lookup) so it pulls
        # ``[777]`` via ``.scalars().first()``.
        session.queue(_Result(rows=[777]))

        async def _racing_revert(_session, *, action, requester_user_id):
            raise IntegrityError("INSERT", {}, Exception("dup reverse_of"))

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route,
                "revert_action",
                AsyncMock(side_effect=_racing_revert),
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-race",
                session=session,
                user=_FakeUser(),
            )
        assert response.failed == 0, (
            "IntegrityError must NOT surface as a failed row; the unique "
            "index is the durable expression of idempotency."
        )
        assert response.already_reverted == 1
        assert response.results[0].status == "already_reverted"
        assert response.results[0].new_action_id == 777

View file

@ -0,0 +1,370 @@
"""Unit tests for the filesystem-tool branches of ``revert_service``.
Covers:
* Exact-name dispatch ``rmdir`` does NOT mis-route to the document
branch (``"rmdir".startswith("rm")`` would mis-route under the legacy
prefix-based dispatch).
* ``rm`` revert re-INSERTs a fresh document from the snapshot, including
re-creating chunks. Falls back to ``(folder_id_before, title_before)``
when ``metadata_before["virtual_path"]`` is missing.
* ``write_file`` create-revert (``content_before IS NULL``) DELETEs the
document.
* ``rmdir`` revert re-INSERTs a fresh folder from the snapshot.
* ``mkdir`` revert DELETEs the empty folder; reports ``tool_unavailable``
when the folder gained children.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from app.services import revert_service
pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _stub_embeddings(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace ``revert_service.embed_texts`` with a deterministic stub.

    Returns one zero vector per input text so no real embedding backend
    is needed — presumably ``embed_texts`` is invoked when chunks are
    re-created during an ``rm`` revert (confirm against revert_service).
    """
    monkeypatch.setattr(
        revert_service,
        "embed_texts",
        lambda texts: [np.zeros(8, dtype=np.float32) for _ in texts],
    )
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeResult:
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
self._rows = rows or []
self._scalar = scalar
def all(self) -> list[Any]:
return list(self._rows)
def scalar_one_or_none(self) -> Any:
return self._scalar
def scalars(self) -> Any:
return _FakeScalarsProxy(self._rows)
class _FakeScalarsProxy:
def __init__(self, rows: list[Any]) -> None:
self._rows = rows
def first(self) -> Any:
return self._rows[0] if self._rows else None
class _FakeSession:
def __init__(self) -> None:
self.execute = AsyncMock()
self.added: list[Any] = []
self.deleted: list[Any] = []
self.flush = AsyncMock()
# session.get(Model, pk) lookup
self.get = AsyncMock(return_value=None)
async def _flush_assigning_ids() -> None:
for obj in self.added:
if getattr(obj, "id", None) is None:
obj.id = 999
self.flush.side_effect = _flush_assigning_ids
def add(self, obj: Any) -> None:
self.added.append(obj)
def add_all(self, objs: list[Any]) -> None:
self.added.extend(objs)
def _action(*, tool_name: str, action_id: int = 7):
return MagicMock(
id=action_id,
tool_name=tool_name,
thread_id=1,
search_space_id=2,
user_id="user-1",
reverse_descriptor=None,
)
def _doc_revision(
*,
document_id: int | None = None,
content_before: str | None = "old content",
title_before: str | None = "notes.md",
folder_id_before: int | None = 5,
chunks_before: list[dict[str, str]] | None = None,
metadata_before: dict[str, str] | None = None,
):
revision = MagicMock()
revision.id = 100
revision.document_id = document_id
revision.search_space_id = 2
revision.content_before = content_before
revision.title_before = title_before
revision.folder_id_before = folder_id_before
revision.chunks_before = chunks_before or []
revision.metadata_before = metadata_before
return revision
def _folder_revision(
*,
folder_id: int | None = None,
name_before: str | None = "team",
parent_id_before: int | None = None,
position_before: str | None = "a0",
):
revision = MagicMock()
revision.id = 200
revision.folder_id = folder_id
revision.search_space_id = 2
revision.name_before = name_before
revision.parent_id_before = parent_id_before
revision.position_before = position_before
return revision
# ---------------------------------------------------------------------------
# Exact-name dispatch regression guards
# ---------------------------------------------------------------------------
class TestExactDispatch:
    """Regression: ``rmdir`` MUST NOT route to the document branch."""

    @pytest.mark.asyncio
    async def test_rmdir_does_not_misroute_to_document(self) -> None:
        # Under the legacy `startswith("rm")` dispatch this action would
        # have hit the document branch; exact-name lookup must send it
        # to `_FOLDER_TOOLS` instead.
        action = _action(tool_name="rmdir")
        session = _FakeSession()
        # No folder revisions exist for this action.
        session.execute.return_value = _FakeResult(rows=[])

        outcome = await revert_service.revert_action(
            session,  # type: ignore[arg-type]
            action=action,
            requester_user_id="user-1",
        )

        assert outcome.status == "not_reversible"
        assert "folder_revisions" in outcome.message

    def test_dispatch_sets_split_doc_and_folder(self) -> None:
        # Static guards on the dispatch tables themselves so a future
        # refactor doesn't accidentally reintroduce the prefix bug.
        doc_tools = revert_service._DOC_TOOLS
        folder_tools = revert_service._FOLDER_TOOLS
        assert "rm" in doc_tools
        assert "rm" not in folder_tools
        assert "rmdir" in folder_tools
        assert "rmdir" not in doc_tools
        # ``move_file`` lives only in document tools (it's a doc rename).
        assert "move_file" in doc_tools
        assert "move_file" not in folder_tools
# ---------------------------------------------------------------------------
# rm revert (re-INSERT)
# ---------------------------------------------------------------------------
class TestRmRevert:
    """``rm`` reverts by re-INSERTing the document from its snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_document_with_chunks(self) -> None:
        session = _FakeSession()
        revision = _doc_revision(
            document_id=None,  # row was hard-deleted
            content_before="hello world",
            title_before="x.md",
            folder_id_before=None,
            chunks_before=[{"content": "alpha"}, {"content": "beta"}],
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        # No collision check hit and the resulting query returns nothing.
        session.execute.return_value = _FakeResult(scalar=None)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )

        assert outcome.status == "ok"
        # New Document + 2 chunks must have been added.
        from app.db import Chunk, Document

        added_docs = [obj for obj in session.added if isinstance(obj, Document)]
        added_chunks = [obj for obj in session.added if isinstance(obj, Chunk)]
        assert len(added_docs) == 1
        assert added_docs[0].title == "x.md"
        assert len(added_chunks) == 2
        # Snapshot was repointed at the new doc id so a follow-up revert works.
        assert revision.document_id == added_docs[0].id

    @pytest.mark.asyncio
    async def test_falls_back_to_folder_id_and_title_for_virtual_path(
        self,
    ) -> None:
        session = _FakeSession()
        # Snapshot with NO metadata_before — the fallback path must kick in.
        revision = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="cap.md",
            folder_id_before=42,
            chunks_before=[],
            metadata_before=None,
        )
        # session.get(Folder, 42) returns a folder with a name.
        folder = MagicMock()
        folder.name = "team"
        folder.parent_id = None
        # First .get is for the folder lookup in the path-derivation.
        session.get = AsyncMock(return_value=folder)
        session.execute.return_value = _FakeResult(scalar=None)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )

        assert outcome.status == "ok"

    @pytest.mark.asyncio
    async def test_falls_back_to_root_path_when_no_folder(
        self,
    ) -> None:
        """metadata_before is None and folder_id_before is None still
        resolves: title fallback yields ``/documents/<title>`` so revert
        proceeds at the root of the documents tree."""
        session = _FakeSession()
        revision = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="x.md",
            folder_id_before=None,
            metadata_before=None,
        )
        # No collision in the documents tree at /documents/x.md.
        session.execute.return_value = _FakeResult(scalar=None)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )

        assert outcome.status == "ok"

    @pytest.mark.asyncio
    async def test_collision_with_live_doc_returns_tool_unavailable(self) -> None:
        session = _FakeSession()
        revision = _doc_revision(
            document_id=None,
            content_before="hi",
            title_before="x.md",
            folder_id_before=None,
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        # SELECT for unique_identifier_hash collision hits an existing row.
        session.execute.return_value = _FakeResult(scalar=42)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )

        assert outcome.status == "tool_unavailable"
        assert "collide" in outcome.message
# ---------------------------------------------------------------------------
# write_file create revert (DELETE)
# ---------------------------------------------------------------------------
class TestWriteFileCreateRevert:
    """``write_file`` that CREATED a document reverts by deleting it."""

    @pytest.mark.asyncio
    async def test_deletes_created_doc(self) -> None:
        # ``content_before is None`` is the marker for "created in this
        # action", so the revert path is a plain DELETE.
        snapshot = _doc_revision(
            document_id=99,
            content_before=None,
            title_before=None,
        )
        session = _FakeSession()

        outcome = await revert_service._delete_created_document(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert outcome.status == "ok"
        # Exactly one statement (the DELETE) went through the session.
        assert session.execute.await_count == 1
# ---------------------------------------------------------------------------
# rmdir revert (re-INSERT folder)
# ---------------------------------------------------------------------------
class TestRmdirRevert:
    """``rmdir`` reverts by re-inserting the folder from its snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_folder_from_snapshot(self) -> None:
        from app.db import Folder

        session = _FakeSession()
        snapshot = _folder_revision(
            folder_id=None,
            name_before="team",
            parent_id_before=None,
            position_before="a0",
        )

        outcome = await revert_service._reinsert_folder_from_revision(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert outcome.status == "ok"
        recreated = [obj for obj in session.added if isinstance(obj, Folder)]
        assert len(recreated) == 1
        assert recreated[0].name == "team"
        # Snapshot is repointed at the new folder id so a follow-up
        # revert of the same turn still resolves.
        assert snapshot.folder_id == recreated[0].id
# ---------------------------------------------------------------------------
# mkdir revert (DELETE folder)
# ---------------------------------------------------------------------------
class TestMkdirRevert:
    """``mkdir`` reverts by deleting the folder — only while it is empty."""

    @pytest.mark.asyncio
    async def test_deletes_empty_folder(self) -> None:
        session = _FakeSession()
        # Execute order: doc-existence check, child-folder check, DELETE.
        session.execute.side_effect = [
            _FakeResult(scalar=None),  # no documents inside
            _FakeResult(scalar=None),  # no child folders
            _FakeResult(scalar=None),  # the DELETE itself
        ]

        outcome = await revert_service._delete_created_folder(
            session,  # type: ignore[arg-type]
            revision=_folder_revision(folder_id=42),
        )

        assert outcome.status == "ok"
        # All three statements ran: docs check, children check, delete.
        assert session.execute.await_count == 3

    @pytest.mark.asyncio
    async def test_reports_tool_unavailable_when_folder_has_children(self) -> None:
        session = _FakeSession()
        # The very first probe (documents) already finds a row, so the
        # folder is no longer empty and deletion must be refused.
        session.execute.return_value = _FakeResult(scalar=1)

        outcome = await revert_service._delete_created_folder(
            session,  # type: ignore[arg-type]
            revision=_folder_revision(folder_id=42),
        )

        assert outcome.status == "tool_unavailable"
        assert "no longer empty" in outcome.message

View file

@ -0,0 +1,185 @@
"""Unit tests for ``stream_new_chat._extract_chunk_parts``.
Earlier versions only handled ``isinstance(chunk.content, str)`` and
silently dropped every other shape (Anthropic typed-block lists,
Bedrock reasoning blocks, ``additional_kwargs.reasoning_content`` from
a few providers). These regression tests pin those four shapes plus the
defensive cases (``None`` chunk, mixed types, missing fields).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import pytest
from app.tasks.chat.stream_new_chat import _extract_chunk_parts
@dataclass
class _FakeChunk:
    """Minimal stand-in for ``AIMessageChunk`` used in unit tests."""

    # Provider "content": a str, a list of typed blocks, or junk
    # (None / int) for the defensive tests.
    content: Any = ""
    # Extra provider payload, e.g. ``{"reasoning_content": ...}``.
    additional_kwargs: dict[str, Any] = field(default_factory=dict)
    # LangChain-style streaming tool-call fragments carried as an
    # attribute (as opposed to typed blocks inside ``content``).
    tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
class TestStringContent:
    """Plain ``str`` content is the baseline provider shape."""

    def test_plain_string_content_extracts_as_text(self) -> None:
        parts = _extract_chunk_parts(_FakeChunk(content="hello world"))
        assert parts["text"] == "hello world"
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []

    def test_empty_string_content_yields_empty_text(self) -> None:
        parts = _extract_chunk_parts(_FakeChunk(content=""))
        assert parts["text"] == ""
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []
class TestListContent:
    """Typed-block list content (Anthropic/Bedrock-style)."""

    def test_list_of_text_blocks_concatenates(self) -> None:
        chunk = _FakeChunk(
            content=[
                {"type": "text", "text": "Hello "},
                {"type": "text", "text": "world"},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "Hello world"
        assert out["reasoning"] == ""

    def test_mixed_text_and_reasoning_blocks(self) -> None:
        # Reasoning blocks may carry their payload under either the
        # ``reasoning`` or the ``text`` key; both must land in "reasoning".
        chunk = _FakeChunk(
            content=[
                {"type": "reasoning", "reasoning": "Let me think... "},
                {"type": "reasoning", "text": "still thinking."},
                {"type": "text", "text": "The answer is 42."},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "The answer is 42."
        assert out["reasoning"] == "Let me think... still thinking."

    def test_tool_call_chunks_in_content_list_extracted(self) -> None:
        chunk = _FakeChunk(
            content=[
                {"type": "text", "text": "Calling tool..."},
                {
                    "type": "tool_call_chunk",
                    "id": "call_123",
                    "name": "make_widget",
                    "args": '{"color":"red"}',
                },
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "Calling tool..."
        assert out["reasoning"] == ""
        assert len(out["tool_call_chunks"]) == 1
        assert out["tool_call_chunks"][0]["id"] == "call_123"
        assert out["tool_call_chunks"][0]["name"] == "make_widget"

    def test_tool_use_blocks_also_extracted(self) -> None:
        """Some providers (Anthropic) emit ``type='tool_use'`` instead."""
        chunk = _FakeChunk(
            content=[
                {
                    "type": "tool_use",
                    "id": "call_xyz",
                    "name": "search",
                },
            ]
        )
        out = _extract_chunk_parts(chunk)
        # The whole block dict is passed through untouched.
        assert out["tool_call_chunks"] == [
            {"type": "tool_use", "id": "call_xyz", "name": "search"}
        ]

    def test_unknown_block_types_are_ignored(self) -> None:
        chunk = _FakeChunk(
            content=[
                {"type": "image_url", "url": "https://example.com/x.png"},
                {"type": "text", "text": "ok"},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "ok"

    def test_blocks_without_text_field_are_ignored(self) -> None:
        chunk = _FakeChunk(
            content=[
                {"type": "text"},  # no text/content key
                {"type": "text", "text": "kept"},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "kept"
class TestAdditionalKwargsReasoning:
    """Reasoning stashed in ``additional_kwargs.reasoning_content``."""

    def test_reasoning_content_in_additional_kwargs(self) -> None:
        # Some providers put reasoning here rather than in typed blocks.
        parts = _extract_chunk_parts(
            _FakeChunk(
                content="visible answer",
                additional_kwargs={"reasoning_content": "internal monologue"},
            )
        )
        assert parts["text"] == "visible answer"
        assert parts["reasoning"] == "internal monologue"

    def test_reasoning_appended_to_typed_block_reasoning(self) -> None:
        # Block reasoning comes first; kwargs reasoning is appended after.
        parts = _extract_chunk_parts(
            _FakeChunk(
                content=[{"type": "reasoning", "text": "from blocks. "}],
                additional_kwargs={"reasoning_content": "from kwargs."},
            )
        )
        assert parts["reasoning"] == "from blocks. from kwargs."
class TestToolCallChunksAttribute:
    """``tool_call_chunks`` attribute merges with typed-block chunks."""

    def test_tool_call_chunks_attribute_extracted_alongside_string_content(
        self,
    ) -> None:
        chunk = _FakeChunk(
            content="streaming text",
            tool_call_chunks=[
                {"name": "save_document", "args": '{"title":"x"}', "id": "tc-9"}
            ],
        )
        parts = _extract_chunk_parts(chunk)
        assert parts["text"] == "streaming text"
        collected = parts["tool_call_chunks"]
        assert len(collected) == 1
        assert collected[0]["id"] == "tc-9"

    def test_attribute_and_typed_block_chunks_both_collected(self) -> None:
        # Content-block chunks are collected before attribute-level ones.
        chunk = _FakeChunk(
            content=[
                {
                    "type": "tool_call_chunk",
                    "id": "from-block",
                    "name": "x",
                }
            ],
            tool_call_chunks=[{"id": "from-attr", "name": "y"}],
        )
        parts = _extract_chunk_parts(chunk)
        collected_ids = [tcc.get("id") for tcc in parts["tool_call_chunks"]]
        assert collected_ids == ["from-block", "from-attr"]
class TestDefensive:
    """Garbage in → empty parts out, never an exception."""

    @pytest.mark.parametrize(
        "chunk_value",
        [None, _FakeChunk(content=None), _FakeChunk(content=42)],
    )
    def test_invalid_chunk_returns_empty_parts(self, chunk_value: Any) -> None:
        parts = _extract_chunk_parts(chunk_value)
        assert parts["text"] == ""
        assert parts["reasoning"] == ""
        assert not parts["tool_call_chunks"]
        assert parts["tool_call_chunks"] == []