Merge remote-tracking branch 'upstream/dev' into feature/multi-agent

This commit is contained in:
CREDO23 2026-05-01 00:05:20 +02:00
commit 5d3b8b9ca9
83 changed files with 10514 additions and 638 deletions

View file

@ -136,6 +136,14 @@ jobs:
AZURE_CODESIGN_ENDPOINT: ${{ vars.AZURE_CODESIGN_ENDPOINT }}
AZURE_CODESIGN_ACCOUNT: ${{ vars.AZURE_CODESIGN_ACCOUNT }}
AZURE_CODESIGN_PROFILE: ${{ vars.AZURE_CODESIGN_PROFILE }}
# macOS Developer ID signing + notarization. Only the macos-latest runner
# consumes these; Windows/Linux runners ignore them. CSC_LINK accepts either
# a file path or a base64-encoded .p12 blob — electron-builder auto-detects.
CSC_LINK: ${{ secrets.MAC_CERT_P12_BASE64 }}
CSC_KEY_PASSWORD: ${{ secrets.MAC_CERT_PASSWORD }}
APPLE_ID: ${{ secrets.APPLE_ID }}
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
# Service principal credentials for Azure.Identity EnvironmentCredential used by the
# TrustedSigning PowerShell module. Only populated when signing is enabled.
# electron-builder 26 does not yet support OIDC federated tokens for Azure signing,

View file

@ -285,6 +285,14 @@ LANGSMITH_PROJECT=surfsense
# SURFSENSE_ENABLE_ACTION_LOG=false
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk
# content (typed reasoning blocks, tool-input deltas) and propagate the
# real tool_call_id to the SSE layer. When OFF, the stream falls back to
# the str-only text path and synthetic "call_<run_id>" tool-call ids.
# Schema migrations 135/136 ship unconditionally because they are
# forward-compatible.
# SURFSENSE_ENABLE_STREAM_PARITY_V2=false
# Plugins
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
# Comma-separated allowlist of plugin entry-point names

View file

@ -0,0 +1,139 @@
"""134_relax_revision_fks
Revision ID: 134
Revises: 133
Create Date: 2026-04-29
Relax the parent FKs on ``document_revisions`` and ``folder_revisions`` so
revisions survive the deletes they describe.
Why: the snapshot/revert pipeline writes a ``DocumentRevision`` BEFORE
hard-deleting a document via the ``rm`` tool (and likewise a
``FolderRevision`` before ``rmdir``). If the FK is ``ON DELETE CASCADE``
the snapshot row is wiped at the exact moment we need it most revert
then has nothing to read and the operation becomes irreversible.
Migration:
* ``document_revisions.document_id``: ``NOT NULL`` -> nullable; FK
``ON DELETE CASCADE`` -> ``ON DELETE SET NULL``.
* ``folder_revisions.folder_id``: same treatment.
The ``search_space_id`` FK on both tables is left unchanged (still
``ON DELETE CASCADE``). When a search space is deleted, all documents,
folders, AND their revisions go together that's the correct teardown
story.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy import inspect
from alembic import op
revision: str = "134"
down_revision: str | None = "133"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def _fk_name(bind, table: str, column: str) -> str | None:
    """Name of the FK constraint covering exactly ``table.column``, or None.

    Looked up via runtime reflection because historical databases may carry
    auto-generated constraint names that differ from the current convention.
    """
    for fk_info in inspect(bind).get_foreign_keys(table):
        if (fk_info.get("constrained_columns") or []) == [column]:
            return fk_info.get("name")
    return None
def upgrade() -> None:
    """Relax parent FKs so revision rows survive parent deletion.

    For each (table, column, referent) pair: drop the existing FK (if
    reflection finds one), make the column nullable, and re-create the FK
    with ``ON DELETE SET NULL``. Constraint names match the originals.
    """
    bind = op.get_bind()
    targets = (
        ("document_revisions", "document_id", "documents"),
        ("folder_revisions", "folder_id", "folders"),
    )
    for table, column, referent in targets:
        existing_fk = _fk_name(bind, table, column)
        if existing_fk:
            op.drop_constraint(existing_fk, table, type_="foreignkey")
        op.alter_column(
            table,
            column,
            existing_type=sa.Integer(),
            nullable=True,
        )
        op.create_foreign_key(
            f"{table}_{column}_fkey",
            table,
            referent,
            [column],
            ["id"],
            ondelete="SET NULL",
        )
def downgrade() -> None:
    """Restore ``NOT NULL`` + ``ON DELETE CASCADE`` on the parent FKs.

    Orphan revisions (parent doc/folder already deleted) would violate the
    reinstated NOT NULL, so they are drained first.
    """
    bind = op.get_bind()
    op.execute("DELETE FROM document_revisions WHERE document_id IS NULL")
    op.execute("DELETE FROM folder_revisions WHERE folder_id IS NULL")
    targets = (
        ("document_revisions", "document_id", "documents"),
        ("folder_revisions", "folder_id", "folders"),
    )
    for table, column, referent in targets:
        existing_fk = _fk_name(bind, table, column)
        if existing_fk:
            op.drop_constraint(existing_fk, table, type_="foreignkey")
        op.alter_column(
            table,
            column,
            existing_type=sa.Integer(),
            nullable=False,
        )
        op.create_foreign_key(
            f"{table}_{column}_fkey",
            table,
            referent,
            [column],
            ["id"],
            ondelete="CASCADE",
        )

View file

@ -0,0 +1,82 @@
"""135_action_log_correlation_ids
Revision ID: 135
Revises: 134
Create Date: 2026-04-29
Action-log correlation-id cleanup.
Background
----------
``agent_action_log.turn_id`` is misnamed. ``ActionLogMiddleware`` writes
the LangChain ``tool_call.id`` into that column today (see
``action_log.py:_resolve_turn_id``), and ``kb_persistence._find_action_ids_batch``
joins on it as such. The real chat-turn id (``f"{chat_id}:{ms}"`` from
``stream_new_chat.py``) lives in ``config.configurable.turn_id`` and was
never persisted.
This migration introduces two new, correctly-named columns:
* ``tool_call_id`` (LangChain tool-call id, what ``turn_id`` actually held)
* ``chat_turn_id`` (the per-turn correlation id from
``configurable.turn_id`` used by the per-turn ``revert-turn`` route).
Backfill copies the current ``turn_id`` values into ``tool_call_id`` so
existing joins keep working. The old ``turn_id`` column is left in place
for one release as a deprecated alias to give safe rollback. ``ActionLogMiddleware``
keeps writing it (= ``tool_call_id``) for the same reason.
Indexes
-------
* ``ix_agent_action_log_tool_call_id`` required by
``_find_action_ids_batch`` (was on ``turn_id``).
* ``ix_agent_action_log_chat_turn_id`` required by the
``revert-turn/{chat_turn_id}`` query.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "135"
down_revision: str | None = "134"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Add ``tool_call_id`` / ``chat_turn_id``, backfill, then index.

    The backfill UPDATE runs BEFORE the indexes are created so Postgres
    builds ``ix_agent_action_log_tool_call_id`` once over the final data
    instead of maintaining it row-by-row while every existing row is
    rewritten. End state is identical to indexing first; this is purely
    cheaper on a large action log.
    """
    op.add_column(
        "agent_action_log",
        sa.Column("tool_call_id", sa.String(length=64), nullable=True),
    )
    op.add_column(
        "agent_action_log",
        sa.Column("chat_turn_id", sa.String(length=64), nullable=True),
    )
    # ``turn_id`` historically held the LangChain tool_call id (see module
    # docstring); copy it across so existing joins keep working.
    op.execute(
        "UPDATE agent_action_log SET tool_call_id = turn_id WHERE tool_call_id IS NULL"
    )
    op.create_index(
        "ix_agent_action_log_tool_call_id",
        "agent_action_log",
        ["tool_call_id"],
    )
    op.create_index(
        "ix_agent_action_log_chat_turn_id",
        "agent_action_log",
        ["chat_turn_id"],
    )
def downgrade() -> None:
    """Drop the correlation-id indexes and columns, newest-first.

    ``turn_id`` itself was never touched by the upgrade, so nothing else
    needs restoring.
    """
    for index_name in (
        "ix_agent_action_log_chat_turn_id",
        "ix_agent_action_log_tool_call_id",
    ):
        op.drop_index(index_name, table_name="agent_action_log")
    for column_name in ("chat_turn_id", "tool_call_id"):
        op.drop_column("agent_action_log", column_name)

View file

@ -0,0 +1,52 @@
"""136_new_chat_message_turn_id
Revision ID: 136
Revises: 135
Create Date: 2026-04-29
Persist the per-turn correlation id on each chat message.
Background
----------
LangGraph's checkpointer stores user-provided ``configurable.turn_id``
in checkpoint metadata (see
``langgraph/checkpoint/base/__init__.py:get_checkpoint_metadata``). To
support edit-from-arbitrary-position, the regenerate route needs to map
a ``message_id`` -> ``turn_id`` -> checkpoint at request time. Without
this column the mapping doesn't exist anywhere, so regenerate would
have to hardcode the "last 2 messages" rewind heuristic.
This migration adds a nullable ``turn_id`` column to ``new_chat_messages``
plus an index. Legacy rows have NULL the regenerate route degrades
gracefully to the reload-last-two heuristic for those.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "136"
down_revision: str | None = "135"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Add the nullable ``turn_id`` correlation column and its index."""
    table = "new_chat_messages"
    op.add_column(
        table,
        sa.Column("turn_id", sa.String(length=64), nullable=True),
    )
    # Required by the regenerate route's message_id -> turn_id lookup.
    op.create_index("ix_new_chat_messages_turn_id", table, ["turn_id"])
def downgrade() -> None:
    """Reverse :func:`upgrade`: drop the index first, then the column."""
    table = "new_chat_messages"
    op.drop_index("ix_new_chat_messages_turn_id", table_name=table)
    op.drop_column(table, "turn_id")

View file

@ -0,0 +1,74 @@
"""137_unique_reverse_of_in_action_log
Revision ID: 137
Revises: 136
Create Date: 2026-04-29
Protect ``agent_action_log.reverse_of`` against double inserts. Two
concurrent revert calls (single-action route + the per-turn batch
route, or two batch routes racing) both pass the
``_was_already_reverted`` SELECT and both insert their own
``_revert:*`` rows. The application-level idempotency check is racy
because there's no DB constraint backing it.
This migration adds a partial unique index on ``reverse_of`` (PostgreSQL
``WHERE reverse_of IS NOT NULL``) so the second concurrent insert raises
``IntegrityError`` and the route can translate it to ``"already_reverted"``
deterministically.
The plain ``UniqueConstraint`` flavour can't be used because most
existing rows have ``reverse_of = NULL`` (only revert rows fill it),
and Postgres does treat NULL as distinct in unique indexes but a
partial index is the cleanest expression of intent and works even on
older Postgres releases that distinguish NULL handling.
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
revision: str = "137"
down_revision: str | None = "136"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
_INDEX_NAME = "ux_agent_action_log_reverse_of"
def upgrade() -> None:
    """De-dup pre-existing double reverts, then add the partial unique index.

    The de-dup keeps the OLDEST row per ``reverse_of`` (smallest id) and
    NULLs out the duplicates' ``reverse_of`` so they survive as audit
    trail but no longer claim to be the canonical revert. We do NOT
    delete them — operators can still inspect them via /actions.
    """
    op.execute(
        """
        WITH dups AS (
            SELECT id,
                   reverse_of,
                   ROW_NUMBER() OVER (
                       PARTITION BY reverse_of ORDER BY id ASC
                   ) AS rn
            FROM agent_action_log
            WHERE reverse_of IS NOT NULL
        )
        UPDATE agent_action_log
        SET reverse_of = NULL
        WHERE id IN (SELECT id FROM dups WHERE rn > 1)
        """
    )
    # Local import: this migration needs no other sqlalchemy names.
    from sqlalchemy import text

    op.create_index(
        _INDEX_NAME,
        "agent_action_log",
        ["reverse_of"],
        unique=True,
        # Wrap the predicate in text(): SQLAlchemy 2.x no longer
        # auto-coerces a raw string into a SQL expression for
        # ``postgresql_where``.
        postgresql_where=text("reverse_of IS NOT NULL"),
    )
def downgrade() -> None:
    """Reverse :func:`upgrade` by dropping the partial unique index.

    The de-dup UPDATE from the upgrade is intentionally NOT reversed —
    the NULLed-out duplicate rows remain as audit trail.
    """
    op.drop_index(_INDEX_NAME, table_name="agent_action_log")

View file

@ -728,7 +728,8 @@ def _build_compiled_agent_blocking(
repair_mw = None
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
registered_names: set[str] = {t.name for t in tools}
# Tools owned by the standard deepagents middleware stack.
# Tools owned by the standard deepagents middleware stack and the
# SurfSense filesystem extension.
registered_names |= {
"write_todos",
"ls",
@ -739,6 +740,14 @@ def _build_compiled_agent_blocking(
"grep",
"execute",
"task",
"mkdir",
"cd",
"pwd",
"move_file",
"rm",
"rmdir",
"list_tree",
"execute_code",
}
repair_mw = ToolCallNameRepairMiddleware(
registered_tool_names=registered_names,
@ -767,25 +776,51 @@ def _build_compiled_agent_blocking(
# on every safe read-only call (``ls``, ``read_file``, ``grep``,
# ``glob``, ``web_search`` …) and, on resume, replay the previous
# reject decision into innocent calls.
# 2. ``connector_synthesized`` — deny rules for tools whose required
# connector is not connected to this space. Overrides #1.
# 3. (future) user-defined rules from ``agent_permission_rules`` table
# via the Agent Permissions UI. Loaded last so they override both.
# 2. ``desktop_safety`` — ``ask`` for destructive filesystem ops when
# the agent is operating against the user's real disk. Cloud mode
# has full revision-based revert via ``revert_service``, but
# desktop mode hits disk immediately with no undo, so an
# accidental ``rm`` / ``rmdir`` / ``move_file`` / ``edit_file`` /
# ``write_file`` is unrecoverable. This layer is forced on in
# desktop mode regardless of ``enable_permission`` because the
# safety net is non-negotiable.
# 3. ``connector_synthesized`` — deny rules for tools whose required
# connector is not connected to this space. Overrides #1/#2.
# 4. (future) user-defined rules from ``agent_permission_rules`` table
# via the Agent Permissions UI. Loaded last so they override all.
permission_mw: PermissionMiddleware | None = None
if flags.enable_permission and not flags.disable_new_agent_stack:
synthesized = _synthesize_connector_deny_rules(
available_connectors=available_connectors,
enabled_tool_names={t.name for t in tools},
)
permission_mw = PermissionMiddleware(
rulesets=[
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
# Build the middleware whenever it has work to do: either the user
# opted into the rule engine, OR we're in desktop mode and need the
# safety rules unconditionally.
if permission_enabled or is_desktop_fs:
rulesets: list[Ruleset] = [
Ruleset(
rules=[Rule(permission="*", pattern="*", action="allow")],
origin="surfsense_defaults",
),
]
if is_desktop_fs:
rulesets.append(
Ruleset(
rules=[Rule(permission="*", pattern="*", action="allow")],
origin="surfsense_defaults",
),
Ruleset(rules=synthesized, origin="connector_synthesized"),
],
)
rules=[
Rule(permission="rm", pattern="*", action="ask"),
Rule(permission="rmdir", pattern="*", action="ask"),
Rule(permission="move_file", pattern="*", action="ask"),
Rule(permission="edit_file", pattern="*", action="ask"),
Rule(permission="write_file", pattern="*", action="ask"),
],
origin="desktop_safety",
)
)
if permission_enabled:
synthesized = _synthesize_connector_deny_rules(
available_connectors=available_connectors,
enabled_tool_names={t.name for t in tools},
)
rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized"))
permission_mw = PermissionMiddleware(rulesets=rulesets)
# ActionLogMiddleware. Off by default until the ``agent_action_log``
# table is migrated. When enabled, persists one row per tool call
@ -942,6 +977,7 @@ def _build_compiled_agent_blocking(
search_space_id=search_space_id,
created_by_id=user_id,
filesystem_mode=filesystem_mode,
thread_id=thread_id,
)
if filesystem_mode == FilesystemMode.CLOUD
else None,

View file

@ -23,6 +23,7 @@ Local development (recommended for trying everything except doom-loop / selector
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events
Master kill-switch (overrides everything else):
@ -86,6 +87,15 @@ class AgentFeatureFlags:
False # Backend ships before UI; route returns 503 until this flips
)
# Streaming parity v2 — opt in to LangChain's structured
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
# deltas) and propagate the real ``tool_call_id`` to the SSE layer.
# When OFF the ``stream_new_chat`` task falls back to the str-only
# text path and the synthetic ``call_<run_id>`` tool-call id (no
# ``langchainToolCallId`` propagation). Schema migrations 135/136
# ship unconditionally because they're forward-compatible.
enable_stream_parity_v2: bool = False
# Plugins
enable_plugin_loader: bool = False
@ -139,6 +149,10 @@ class AgentFeatureFlags:
# Snapshot / revert
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
# Streaming parity v2
enable_stream_parity_v2=_env_bool(
"SURFSENSE_ENABLE_STREAM_PARITY_V2", False
),
# Plugins
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
# Observability

View file

@ -5,9 +5,14 @@ extra fields needed to implement Postgres-backed virtual filesystem semantics:
* ``cwd`` current working directory (per-thread checkpointed).
* ``staged_dirs`` pending mkdir requests (cloud only).
* ``staged_dir_tool_calls`` sidecar map ``path -> tool_call_id`` for staged dirs.
* ``pending_moves`` pending move_file requests (cloud only).
* ``pending_deletes`` pending ``rm`` requests (cloud only).
* ``pending_dir_deletes`` pending ``rmdir`` requests (cloud only).
* ``doc_id_by_path`` virtual_path -> Document.id, populated by lazy reads.
* ``dirty_paths`` paths whose state file content differs from DB.
* ``dirty_path_tool_calls`` sidecar map ``path -> latest tool_call_id`` for
dirty paths; used to bind the per-path snapshot to an action_id.
* ``kb_priority`` top-K priority hints rendered into a system message.
* ``kb_matched_chunk_ids`` internal hand-off for matched-chunk highlighting.
* ``kb_anon_doc`` Redis-loaded anonymous document (if any).
@ -32,12 +37,31 @@ from app.agents.new_chat.state_reducers import (
)
class PendingMove(TypedDict):
"""A staged move_file operation pending end-of-turn commit."""
class PendingMove(TypedDict, total=False):
    """A staged move_file operation pending end-of-turn commit.

    ``tool_call_id`` is optional for backward compatibility with checkpoints
    written before the snapshot/revert pipeline was wired up; new entries
    always include it so the persistence body can resolve an action_id.
    """

    # Canonical source path of the file being moved.
    source: str
    # Destination path the file moves to.
    dest: str
    # Whether an existing file at ``dest`` may be replaced.
    overwrite: bool
    # LangChain tool-call id of the originating ``move_file`` call; may be
    # absent in checkpoints written before the revert pipeline existed.
    tool_call_id: str
class PendingDelete(TypedDict, total=False):
    """A staged ``rm`` or ``rmdir`` operation pending end-of-turn commit.

    ``tool_call_id`` is required for new entries (it's the binding key used
    by :class:`KnowledgeBasePersistenceMiddleware` to find the matching
    :class:`AgentActionLog` row and bind the snapshot to it). Marked
    ``total=False`` only to tolerate older checkpoint payloads.
    """

    # Canonical path of the file (``rm``) or directory (``rmdir``) to delete.
    path: str
    # LangChain tool-call id of the originating delete call.
    tool_call_id: str
class KbPriorityEntry(TypedDict, total=False):
@ -76,9 +100,38 @@ class SurfSenseFilesystemState(FilesystemState):
staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
"""mkdir paths staged for end-of-turn folder creation (cloud only)."""
staged_dir_tool_calls: NotRequired[
Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
]
"""``path -> tool_call_id`` sidecar for ``staged_dirs``.
Used by :class:`KnowledgeBasePersistenceMiddleware` to bind the
:class:`FolderRevision` snapshot to the originating ``mkdir`` action.
Kept separate from ``staged_dirs`` (which stays a unique-string list)
to avoid breaking ``_add_unique_reducer`` semantics.
"""
pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
"""move_file ops staged for end-of-turn commit (cloud only)."""
pending_deletes: NotRequired[Annotated[list[PendingDelete], _list_append_reducer]]
"""``rm`` ops staged for end-of-turn ``DELETE FROM documents`` (cloud only).
Each entry is a dict ``{"path": ..., "tool_call_id": ...}``. Per-path
uniqueness is enforced inside the commit body, not the reducer (we keep
``tool_call_id`` per occurrence so snapshot binding works).
"""
pending_dir_deletes: NotRequired[
Annotated[list[PendingDelete], _list_append_reducer]
]
"""``rmdir`` ops staged for end-of-turn ``DELETE FROM folders`` (cloud only).
Same shape as :data:`pending_deletes`. Commit body re-verifies the
folder is empty (in-DB AND with this turn's pending changes accounted
for) before issuing the DELETE.
"""
doc_id_by_path: NotRequired[
Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
]
@ -92,6 +145,17 @@ class SurfSenseFilesystemState(FilesystemState):
dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
"""Paths whose ``state["files"]`` content has been modified this turn."""
dirty_path_tool_calls: NotRequired[
Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
]
"""``path -> latest tool_call_id`` sidecar for ``dirty_paths``.
The persistence body coalesces multiple writes/edits to the same path
into one snapshot per turn. This map captures the most-recent
``tool_call_id`` so the resulting :class:`DocumentRevision` is bound
to the latest action_id (the one the user is most likely to revert).
"""
kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
"""Top-K priority hints rendered as a system message before the user turn."""
@ -108,6 +172,7 @@ class SurfSenseFilesystemState(FilesystemState):
__all__ = [
"KbAnonDoc",
"KbPriorityEntry",
"PendingDelete",
"PendingMove",
"SurfSenseFilesystemState",
]

View file

@ -30,6 +30,7 @@ from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import AgentMiddleware
from langchain_core.callbacks import adispatch_custom_event
from langchain_core.messages import ToolMessage
from app.agents.new_chat.feature_flags import get_flags
@ -144,11 +145,19 @@ class ActionLogMiddleware(AgentMiddleware):
result=result,
)
tool_call_id = _resolve_tool_call_id(request)
chat_turn_id = _resolve_chat_turn_id(request)
row = AgentActionLog(
thread_id=self._thread_id,
user_id=self._user_id,
search_space_id=self._search_space_id,
turn_id=_resolve_turn_id(request),
# ``turn_id`` is the deprecated alias of ``tool_call_id``
# kept for one release for safe rollback. New consumers
# should read ``tool_call_id`` directly.
turn_id=tool_call_id,
tool_call_id=tool_call_id,
chat_turn_id=chat_turn_id,
message_id=_resolve_message_id(request),
tool_name=tool_name,
args=args_payload,
@ -160,11 +169,41 @@ class ActionLogMiddleware(AgentMiddleware):
async with shielded_async_session() as session:
session.add(row)
await session.commit()
row_id = int(row.id) if row.id is not None else None
row_created_at = row.created_at
except Exception:
logger.warning(
"ActionLogMiddleware failed to persist action log row",
exc_info=True,
)
return
# Surface a side-channel SSE event so the chat tool card can
# render a Revert button immediately after the row is durable.
# ``stream_new_chat`` translates this into a
# ``data-action-log`` SSE event. We DO NOT include the
# ``reverse_descriptor`` payload here; only a presence flag.
try:
await adispatch_custom_event(
"action_log",
{
"id": row_id,
"lc_tool_call_id": tool_call_id,
"chat_turn_id": chat_turn_id,
"tool_name": tool_name,
"reversible": bool(reversible),
"reverse_descriptor_present": reverse_descriptor is not None,
"created_at": row_created_at.isoformat()
if row_created_at
else None,
"error": error_payload is not None,
},
)
except Exception:
logger.debug(
"ActionLogMiddleware failed to dispatch action_log event",
exc_info=True,
)
def _render_reverse(
self,
@ -254,7 +293,8 @@ def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
}
def _resolve_turn_id(request: Any) -> str | None:
def _resolve_tool_call_id(request: Any) -> str | None:
"""Return the LangChain ``tool_call.id`` for this request, if any."""
try:
call = getattr(request, "tool_call", None) or {}
if isinstance(call, dict):
@ -266,9 +306,40 @@ def _resolve_turn_id(request: Any) -> str | None:
return None
# Deprecated alias kept for one release. Old callers and tests treated
# ``turn_id`` as if it carried the LangChain tool_call id; the new column
# lives under ``tool_call_id``. Both resolve to the same value today.
_resolve_turn_id = _resolve_tool_call_id
def _resolve_chat_turn_id(request: Any) -> str | None:
"""Return ``configurable.turn_id`` for this request, if accessible.
``ToolRuntime.config`` is exposed by LangGraph (see
``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id
lives at ``runtime.config["configurable"]["turn_id"]``.
"""
try:
runtime = getattr(request, "runtime", None)
if runtime is None:
return None
config = getattr(runtime, "config", None)
if not isinstance(config, dict):
return None
configurable = config.get("configurable")
if not isinstance(configurable, dict):
return None
value = configurable.get("turn_id")
if isinstance(value, str) and value:
return value
except Exception: # pragma: no cover - defensive
pass
return None
def _resolve_message_id(request: Any) -> str | None:
    """Best-available message correlator at this layer.

    No real message id exists on the tool-call path, so the LangChain
    tool-call id stands in for it.
    """
    return _resolve_tool_call_id(request)
def _resolve_result_id(result: Any) -> str | None:

View file

@ -102,6 +102,8 @@ current working directory (`cwd`, default `/documents`).
- cd(path): change the current working directory.
- pwd(): print the current working directory.
- move_file(source, dest): move/rename a file under `/documents/`.
- rm(path): delete a single file under `/documents/` (no `-r`).
- rmdir(path): delete an empty directory under `/documents/`.
- list_tree(path, max_depth, page_size): recursively list files/folders.
## Persistence Rules
@ -112,8 +114,9 @@ current working directory (`cwd`, default `/documents`).
`/documents/temp_scratch.md`) are **discarded** at end of turn use this
prefix for any scratch/working content you do NOT want saved.
- All other paths (outside `/documents/` and not `temp_*`) are rejected.
- mkdir/move_file are staged this turn and committed at end of turn alongside
any new/edited documents.
- mkdir/move_file/rm/rmdir are staged this turn and committed at end of
turn alongside any new/edited documents. Snapshot/revert is enabled
for every destructive operation when action logging is on.
## Reading Documents Efficiently
@ -176,6 +179,8 @@ directory (`cwd`).
- cd(path): change the current working directory.
- pwd(): print the current working directory.
- move_file(source, dest): move/rename a file.
- rm(path): delete a single file from disk (no `-r`). NOT reversible.
- rmdir(path): delete an empty directory from disk. NOT reversible.
- list_tree(path, max_depth, page_size): recursively list files/folders.
## Workflow Tips
@ -184,6 +189,8 @@ directory (`cwd`).
- For large trees, prefer `list_tree` then `grep` then `read_file` over
brute-force directory traversal.
- Cross-mount moves are not supported.
- Desktop deletes hit disk immediately and cannot be undone via the
agent's revert flow — confirm before calling `rm`/`rmdir`.
"""
)
@ -355,6 +362,42 @@ Notes:
- Parent folders are created as needed.
"""
_CLOUD_RM_TOOL_DESCRIPTION = """Deletes a single file under `/documents/`.
Mirrors POSIX `rm path` (no `-r`, no glob expansion). Stages the deletion
for end-of-turn commit; the row is removed only after the agent's turn
finishes successfully.
Args:
- path: absolute or relative file path. Cannot point at a directory use
`rmdir` for empty folders. Cannot target the root or `/documents`.
Notes:
- The action is reversible via the per-action revert flow when action
logging is enabled.
- The anonymous uploaded document is read-only and cannot be deleted.
"""
_CLOUD_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory under `/documents/`.
Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
deletion (`rm -r`) is intentionally NOT supported clear contents with
`rm` first.
Args:
- path: absolute or relative directory path. Cannot target the root,
`/documents`, the current cwd, or any ancestor of cwd (use `cd` to
move out first).
Notes:
- Emptiness is evaluated against the post-staged view, so a same-turn
`rm /a/x.md` followed by `rmdir /a` is fine.
- If the directory was added in this same turn via `mkdir` and never
committed, the staged mkdir is dropped instead of issuing a delete.
- The action is reversible via the per-action revert flow when action
logging is enabled.
"""
# --- desktop-only ----------------------------------------------------------
_DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
@ -421,6 +464,28 @@ Notes:
- Parent folders are created as needed.
"""
_DESKTOP_RM_TOOL_DESCRIPTION = """Deletes a single file from disk.
Mirrors POSIX `rm path` (no `-r`, no glob expansion). The deletion hits
disk immediately. Desktop deletes are NOT reversible via the agent's
revert flow.
Args:
- path: absolute mount-prefixed file path. Cannot point at a directory
use `rmdir` for empty folders.
"""
_DESKTOP_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory from disk.
Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
deletion is NOT supported. The deletion hits disk immediately and is
NOT reversible via the agent's revert flow.
Args:
- path: absolute mount-prefixed directory path. Cannot target the mount
root or any directory containing files/subfolders.
"""
def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
"""Pick the active-mode description for every filesystem tool."""
@ -437,6 +502,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
"mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION,
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
"rm": _CLOUD_RM_TOOL_DESCRIPTION,
"rmdir": _CLOUD_RMDIR_TOOL_DESCRIPTION,
}
return {
"ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION,
@ -450,6 +517,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
"mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION,
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
"rm": _DESKTOP_RM_TOOL_DESCRIPTION,
"rmdir": _DESKTOP_RMDIR_TOOL_DESCRIPTION,
}
@ -476,6 +545,21 @@ def _basename(path: str) -> str:
return path.rsplit("/", 1)[-1]
def _is_ancestor_of(candidate: str, target: str) -> bool:
"""True iff ``candidate`` is a strict ancestor directory of ``target``.
``target`` itself is NOT considered an ancestor (use equality for that).
Both paths are assumed to be canonicalised, absolute, and free of
trailing slashes (except the root ``/``).
"""
if not candidate.startswith("/") or not target.startswith("/"):
return False
if candidate == target:
return False
prefix = candidate.rstrip("/") + "/"
return target.startswith(prefix)
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
"""SurfSense-specific filesystem middleware (cloud + desktop)."""
@ -519,6 +603,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
self.tools.append(self._create_cd_tool())
self.tools.append(self._create_pwd_tool())
self.tools.append(self._create_move_file_tool())
self.tools.append(self._create_rm_tool())
self.tools.append(self._create_rmdir_tool())
self.tools.append(self._create_list_tree_tool())
if self._sandbox_available:
self.tools.append(self._create_execute_code_tool())
@ -941,6 +1027,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
}
if self._is_cloud():
update["dirty_paths"] = [path]
update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
return Command(update=update)
def sync_write_file(
@ -1036,6 +1123,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
}
if self._is_cloud():
update["dirty_paths"] = [path]
update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
if doc_id_to_attach is not None:
update["doc_id_by_path"] = {path: doc_id_to_attach}
return Command(update=update)
@ -1103,6 +1191,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
return Command(
update={
"staged_dirs": [validated],
"staged_dir_tool_calls": {
validated: runtime.tool_call_id,
},
"messages": [
ToolMessage(
content=(
@ -1372,7 +1463,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
files_update: dict[str, Any] = {source: None, dest: source_file_data}
update: dict[str, Any] = {
"files": files_update,
"pending_moves": [{"source": source, "dest": dest, "overwrite": False}],
"pending_moves": [
{
"source": source,
"dest": dest,
"overwrite": False,
"tool_call_id": runtime.tool_call_id,
}
],
"messages": [
ToolMessage(
content=(
@ -1396,6 +1494,323 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
update["dirty_paths"] = new_dirty
return Command(update=update)
# ------------------------------------------------------------------ tool: rm
    def _create_rm_tool(self) -> BaseTool:
        """Build the ``rm`` tool (single-file delete).

        Cloud mode stages the delete in graph state (``pending_deletes``)
        for an end-of-turn commit; desktop mode deletes from disk
        immediately via the backend's ``adelete_file``.
        """
        # Fall back to the cloud description; deployments may override via
        # self._custom_tool_descriptions (e.g. the desktop description map).
        tool_description = (
            self._custom_tool_descriptions.get("rm") or _CLOUD_RM_TOOL_DESCRIPTION
        )

        async def async_rm(
            path: Annotated[
                str,
                "Absolute or relative path to the file to delete.",
            ],
            runtime: ToolRuntime[None, SurfSenseFilesystemState],
        ) -> Command | str:
            """Async implementation: returns a state ``Command`` on success,
            or a plain ``"Error: ..."`` string the model can read."""
            if not path or not path.strip():
                return "Error: path is required."
            target = self._resolve_relative(path, runtime)
            try:
                validated = validate_path(target)
            except ValueError as exc:
                return f"Error: {exc}"
            if self._is_cloud():
                # Guard the filesystem roots and the /documents sandbox.
                if validated in ("/", DOCUMENTS_ROOT):
                    return f"Error: refusing to rm '{validated}'."
                if not validated.startswith(DOCUMENTS_ROOT + "/"):
                    return (
                        "Error: cloud rm must target a path under /documents/ "
                        f"(got '{validated}')."
                    )
                # The anonymous uploaded document is read-only by contract.
                anon = runtime.state.get("kb_anon_doc") or {}
                if isinstance(anon, dict) and str(anon.get("path") or "") == validated:
                    return "Error: the anonymous uploaded document is read-only."
                # Refuse if the path looks like a directory.
                staged_dirs = list(runtime.state.get("staged_dirs") or [])
                if validated in staged_dirs:
                    return (
                        f"Error: '{validated}' is a directory. Use rmdir for "
                        "empty directories."
                    )
                pending_dir_deletes = list(
                    runtime.state.get("pending_dir_deletes") or []
                )
                if any(
                    isinstance(d, dict) and d.get("path") == validated
                    for d in pending_dir_deletes
                ):
                    return f"Error: '{validated}' is already queued for rmdir."
                backend = self._get_backend(runtime)
                if isinstance(backend, KBPostgresBackend):
                    # Detect "is a directory" via `ls`: if the path lists
                    # children we know it's a folder. Otherwise we still
                    # need to confirm it's a real file before staging.
                    children = await backend.als_info(validated)
                    if children:
                        return (
                            f"Error: '{validated}' is a directory. Use rmdir for "
                            "empty directories."
                        )
                # Already queued for delete this turn?
                pending_deletes = list(runtime.state.get("pending_deletes") or [])
                if any(
                    isinstance(d, dict) and d.get("path") == validated
                    for d in pending_deletes
                ):
                    return f"'{validated}' is already queued for deletion."
                # Resolve doc_id (best-effort): file in state or DB.
                # NOTE(review): resolved_doc_id only acts as an existence
                # probe here — the id itself never lands in the update dict.
                files_state = runtime.state.get("files") or {}
                doc_id_by_path = runtime.state.get("doc_id_by_path") or {}
                resolved_doc_id: int | None = doc_id_by_path.get(validated)
                if (
                    validated not in files_state
                    and resolved_doc_id is None
                    and isinstance(backend, KBPostgresBackend)
                ):
                    loaded = await backend._load_file_data(validated)
                    if loaded is None:
                        return f"Error: file '{validated}' not found."
                    _, resolved_doc_id = loaded
                # Tombstone the file in state (None value) and stage the
                # delete; the real DB mutation happens at end-of-turn commit.
                files_update: dict[str, Any] = {validated: None}
                update: dict[str, Any] = {
                    "pending_deletes": [
                        {
                            "path": validated,
                            "tool_call_id": runtime.tool_call_id,
                        }
                    ],
                    "files": files_update,
                    "doc_id_by_path": {validated: None},
                    "messages": [
                        ToolMessage(
                            content=(
                                f"Staged delete of '{validated}' (will commit at "
                                "end of turn)."
                            ),
                            tool_call_id=runtime.tool_call_id,
                        )
                    ],
                }
                # Drop the path from dirty_paths so a same-turn write+rm
                # doesn't recreate the doc at commit time.
                dirty_paths = list(runtime.state.get("dirty_paths") or [])
                if validated in dirty_paths:
                    new_dirty: list[Any] = [_CLEAR]
                    for entry in dirty_paths:
                        if entry != validated:
                            new_dirty.append(entry)
                    update["dirty_paths"] = new_dirty
                    update["dirty_path_tool_calls"] = {validated: None}
                return Command(update=update)

            # Desktop mode — hit disk immediately.
            backend = self._get_backend(runtime)
            adelete = getattr(backend, "adelete_file", None)
            if not callable(adelete):
                return "Error: rm is not supported by the active backend."
            res: WriteResult = await adelete(validated)
            if res.error:
                return res.error
            update_desktop: dict[str, Any] = {
                "files": {validated: None},
                "messages": [
                    ToolMessage(
                        content=f"Deleted file '{res.path or validated}'",
                        tool_call_id=runtime.tool_call_id,
                    )
                ],
            }
            return Command(update=update_desktop)

        def sync_rm(
            path: Annotated[
                str,
                "Absolute or relative path to the file to delete.",
            ],
            runtime: ToolRuntime[None, SurfSenseFilesystemState],
        ) -> Command | str:
            # Sync facade for non-async executors; blocks on the coroutine.
            return self._run_async_blocking(async_rm(path, runtime))

        return StructuredTool.from_function(
            name="rm",
            description=tool_description,
            func=sync_rm,
            coroutine=async_rm,
        )
# ------------------------------------------------------------------ tool: rmdir
    def _create_rmdir_tool(self) -> BaseTool:
        """Build the ``rmdir`` tool (empty-directory delete).

        Cloud mode either stages the delete (``pending_dir_deletes``) or
        simply un-stages a same-turn ``mkdir``; desktop mode removes the
        empty directory from disk immediately via the backend's ``armdir``.
        """
        tool_description = (
            self._custom_tool_descriptions.get("rmdir") or _CLOUD_RMDIR_TOOL_DESCRIPTION
        )

        async def async_rmdir(
            path: Annotated[
                str,
                "Absolute or relative path of the empty directory to delete.",
            ],
            runtime: ToolRuntime[None, SurfSenseFilesystemState],
        ) -> Command | str:
            """Async implementation: returns a state ``Command`` on success,
            or a plain ``"Error: ..."`` string the model can read."""
            if not path or not path.strip():
                return "Error: path is required."
            target = self._resolve_relative(path, runtime)
            try:
                validated = validate_path(target)
            except ValueError as exc:
                return f"Error: {exc}"
            if self._is_cloud():
                # Guard the filesystem roots and the /documents sandbox.
                if validated in ("/", DOCUMENTS_ROOT):
                    return f"Error: refusing to rmdir '{validated}'."
                if not validated.startswith(DOCUMENTS_ROOT + "/"):
                    return (
                        "Error: cloud rmdir must target a path under /documents/ "
                        f"(got '{validated}')."
                    )
                # Deleting the cwd (or an ancestor of it) would strand the
                # agent in a nonexistent directory.
                cwd = self._current_cwd(runtime)
                if validated == cwd or _is_ancestor_of(validated, cwd):
                    return (
                        f"Error: cannot rmdir '{validated}' because the current "
                        "cwd is at or under it. cd out first."
                    )
                staged_dirs = list(runtime.state.get("staged_dirs") or [])
                pending_dir_deletes = list(
                    runtime.state.get("pending_dir_deletes") or []
                )
                if any(
                    isinstance(d, dict) and d.get("path") == validated
                    for d in pending_dir_deletes
                ):
                    return f"'{validated}' is already queued for deletion."
                backend = self._get_backend(runtime)
                # The path must currently exist either in DB folder paths or
                # in staged_dirs. We rely on KBPostgresBackend.als_info (which
                # already accounts for pending deletes/moves) to evaluate
                # both existence and emptiness against the post-staged view.
                exists_in_staged = validated in staged_dirs
                children: list[Any] = []
                if isinstance(backend, KBPostgresBackend):
                    children = list(await backend.als_info(validated))
                # Detect "is a file" — if als_info returns no children but
                # the path is actually a file, we should reject. We use
                # _load_file_data to disambiguate file vs missing folder.
                if (
                    isinstance(backend, KBPostgresBackend)
                    and not children
                    and not exists_in_staged
                ):
                    loaded = await backend._load_file_data(validated)
                    if loaded is not None:
                        return (
                            f"Error: '{validated}' is a file. Use rm to delete files."
                        )
                    # Confirm folder exists in DB by checking the parent listing.
                    parent = posixpath.dirname(validated) or "/"
                    parent_listing = await backend.als_info(parent)
                    parent_has_dir = any(
                        info.get("path") == validated and info.get("is_dir")
                        for info in parent_listing
                    )
                    if not parent_has_dir:
                        return f"Error: directory '{validated}' not found."
                if children:
                    return (
                        f"Error: directory '{validated}' is not empty. "
                        "Remove contents first."
                    )
                # Same-turn mkdir un-stage: drop the staged_dirs entry
                # entirely and skip queuing a DB delete (nothing was ever
                # committed).
                if exists_in_staged:
                    rest = [d for d in staged_dirs if d != validated]
                    return Command(
                        update={
                            "staged_dirs": [_CLEAR, *rest],
                            "staged_dir_tool_calls": {validated: None},
                            "messages": [
                                ToolMessage(
                                    content=(f"Un-staged directory '{validated}'."),
                                    tool_call_id=runtime.tool_call_id,
                                )
                            ],
                        }
                    )
                return Command(
                    update={
                        "pending_dir_deletes": [
                            {
                                "path": validated,
                                "tool_call_id": runtime.tool_call_id,
                            }
                        ],
                        "messages": [
                            ToolMessage(
                                content=(
                                    f"Staged rmdir of '{validated}' (will commit "
                                    "at end of turn)."
                                ),
                                tool_call_id=runtime.tool_call_id,
                            )
                        ],
                    }
                )

            # Desktop mode — hit disk immediately.
            backend = self._get_backend(runtime)
            armdir = getattr(backend, "armdir", None)
            if not callable(armdir):
                return "Error: rmdir is not supported by the active backend."
            res: WriteResult = await armdir(validated)
            if res.error:
                return res.error
            return Command(
                update={
                    "messages": [
                        ToolMessage(
                            content=f"Deleted directory '{res.path or validated}'",
                            tool_call_id=runtime.tool_call_id,
                        )
                    ],
                }
            )

        def sync_rmdir(
            path: Annotated[
                str,
                "Absolute or relative path of the empty directory to delete.",
            ],
            runtime: ToolRuntime[None, SurfSenseFilesystemState],
        ) -> Command | str:
            # Sync facade for non-async executors; blocks on the coroutine.
            return self._run_async_blocking(async_rmdir(path, runtime))

        return StructuredTool.from_function(
            name="rmdir",
            description=tool_description,
            func=sync_rmdir,
            coroutine=async_rmdir,
        )
# ------------------------------------------------------------------ tool: list_tree
def _create_list_tree_tool(self) -> BaseTool:

View file

@ -115,6 +115,12 @@ class KBPostgresBackend(BackendProtocol):
def _pending_moves(self) -> list[dict[str, Any]]:
return list(self.state.get("pending_moves") or [])
def _pending_deletes(self) -> list[dict[str, Any]]:
return list(self.state.get("pending_deletes") or [])
def _pending_dir_deletes(self) -> list[dict[str, Any]]:
return list(self.state.get("pending_dir_deletes") or [])
def _kb_anon_doc(self) -> dict[str, Any] | None:
anon = self.state.get("kb_anon_doc")
return anon if isinstance(anon, dict) else None
@ -140,18 +146,28 @@ class KBPostgresBackend(BackendProtocol):
return path
return path.rstrip("/") if path != "/" else path
def _moved_view_paths(
def _pending_filesystem_view(
self,
existing: dict[str, dict[str, Any]],
) -> tuple[set[str], dict[str, str]]:
"""Apply ``pending_moves`` to a path set and return ``(removed, alias)``.
) -> tuple[set[str], dict[str, str], set[str]]:
"""Compute removed/aliased/dir-suppressed paths from staged ops.
Removed paths should disappear from listings; ``alias[source] = dest``
means a virtual entry should appear at ``dest`` even if no DB row is
yet there.
Returns ``(removed, alias, deleted_dirs)`` where:
* ``removed`` paths to drop from listings (sources of pending moves
AND paths queued for ``rm``).
* ``alias`` ``{source: dest}`` for pending moves; the dest should
appear as a virtual entry even when no DB row is at that path yet.
* ``deleted_dirs`` folder paths queued for ``rmdir``; their entire
subtree (descendants) is suppressed from listings/glob/grep.
Entries in ``existing`` (the ``files`` state cache) keyed by a
removed path are popped so a same-turn delete-after-write doesn't
leave a stale virtual file in listings.
"""
removed: set[str] = set()
alias: dict[str, str] = {}
deleted_dirs: set[str] = set()
for move in self._pending_moves():
src = move.get("source")
dst = move.get("dest")
@ -160,7 +176,23 @@ class KBPostgresBackend(BackendProtocol):
removed.add(src)
alias[src] = dst
existing.pop(src, None)
return removed, alias
for entry in self._pending_deletes():
path = entry.get("path") if isinstance(entry, dict) else None
if not path:
continue
removed.add(path)
existing.pop(path, None)
for entry in self._pending_dir_deletes():
path = entry.get("path") if isinstance(entry, dict) else None
if not path:
continue
deleted_dirs.add(path)
return removed, alias, deleted_dirs
@staticmethod
def _is_dir_suppressed(path: str, deleted_dirs: set[str]) -> bool:
"""Return True iff ``path`` is at-or-under any directory in ``deleted_dirs``."""
return any(path == d or _is_under(path, d) for d in deleted_dirs)
# ------------------------------------------------------------------ ls/read
@ -189,7 +221,7 @@ class KBPostgresBackend(BackendProtocol):
seen.add(anon_path)
files = self._state_files()
moved_removed, moved_alias = self._moved_view_paths(files)
moved_removed, moved_alias, deleted_dirs = self._pending_filesystem_view(files)
if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
try:
@ -203,7 +235,12 @@ class KBPostgresBackend(BackendProtocol):
for info in db_infos:
p = info.get("path", "")
if not p or p in seen or p in moved_removed:
if (
not p
or p in seen
or p in moved_removed
or self._is_dir_suppressed(p, deleted_dirs)
):
continue
infos.append(info)
seen.add(p)
@ -212,6 +249,8 @@ class KBPostgresBackend(BackendProtocol):
if src not in seen:
if not _is_under(dst, normalized):
continue
if self._is_dir_suppressed(dst, deleted_dirs):
continue
rel = (
dst[len(normalized) :].lstrip("/")
if normalized != "/"
@ -247,6 +286,8 @@ class KBPostgresBackend(BackendProtocol):
continue
if not _is_under(staged, normalized):
continue
if self._is_dir_suppressed(staged, deleted_dirs):
continue
rel = (
staged[len(normalized) :].lstrip("/")
if normalized != "/"
@ -265,14 +306,26 @@ class KBPostgresBackend(BackendProtocol):
for sub in sorted(subdir_paths):
if sub in seen:
continue
if self._is_dir_suppressed(sub, deleted_dirs):
continue
infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at=""))
seen.add(sub)
for path_key, fd in files.items():
if not isinstance(path_key, str) or path_key in seen:
continue
# Tombstones (None values) are deletion markers from `rm`. The
# deepagents reducer normally pops them, but a stale tombstone
# surviving a checkpoint must NOT be reported as a child here —
# otherwise rmdir mistakenly sees the deleted file as content.
if fd is None:
continue
if not _is_under(path_key, normalized) or path_key == normalized:
continue
if path_key in moved_removed or self._is_dir_suppressed(
path_key, deleted_dirs
):
continue
if normalized == "/":
rel = path_key.lstrip("/")
else:
@ -550,10 +603,12 @@ class KBPostgresBackend(BackendProtocol):
seen: set[str] = set()
files = self._state_files()
moved_removed, _ = self._moved_view_paths(files)
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
regex = re.compile(fnmatch.translate(pattern))
for path_key, fd in files.items():
if path_key in moved_removed:
if path_key in moved_removed or self._is_dir_suppressed(
path_key, deleted_dirs
):
continue
if not _is_under(path_key, normalized):
continue
@ -595,7 +650,11 @@ class KBPostgresBackend(BackendProtocol):
folder_id=row.folder_id,
index=index,
)
if candidate in seen or candidate in moved_removed:
if (
candidate in seen
or candidate in moved_removed
or self._is_dir_suppressed(candidate, deleted_dirs)
):
continue
if not _is_under(candidate, normalized):
continue
@ -634,10 +693,12 @@ class KBPostgresBackend(BackendProtocol):
matches: list[GrepMatch] = []
files = self._state_files()
moved_removed, _ = self._moved_view_paths(files)
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
glob_re = re.compile(fnmatch.translate(glob)) if glob else None
for path_key, fd in files.items():
if path_key in moved_removed:
if path_key in moved_removed or self._is_dir_suppressed(
path_key, deleted_dirs
):
continue
if not _is_under(path_key, normalized):
continue
@ -695,7 +756,11 @@ class KBPostgresBackend(BackendProtocol):
)
for doc_id, chunk_id, content in chunk_buffer:
candidate = doc_id_to_path.get(doc_id)
if not candidate or candidate in moved_removed:
if (
not candidate
or candidate in moved_removed
or self._is_dir_suppressed(candidate, deleted_dirs)
):
continue
if not _is_under(candidate, normalized):
continue
@ -769,7 +834,7 @@ class KBPostgresBackend(BackendProtocol):
return {"entries": [], "truncated": False}
files = self._state_files()
moved_removed, _ = self._moved_view_paths(files)
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
anon = self._kb_anon_doc()
anon_path = str(anon.get("path") or "") if anon else ""
@ -795,6 +860,8 @@ class KBPostgresBackend(BackendProtocol):
for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]):
if not _is_under(fpath, normalized):
continue
if self._is_dir_suppressed(fpath, deleted_dirs):
continue
depth = _depth_of(fpath)
if max_depth is not None and depth > max_depth:
continue
@ -811,6 +878,8 @@ class KBPostgresBackend(BackendProtocol):
for staged in self._staged_dirs():
if not _is_under(staged, normalized):
continue
if self._is_dir_suppressed(staged, deleted_dirs):
continue
depth = _depth_of(staged)
if max_depth is not None and depth > max_depth:
continue
@ -835,7 +904,9 @@ class KBPostgresBackend(BackendProtocol):
folder_id=row.folder_id,
index=index,
)
if candidate in moved_removed:
if candidate in moved_removed or self._is_dir_suppressed(
candidate, deleted_dirs
):
continue
if not _is_under(candidate, normalized):
continue
@ -875,6 +946,10 @@ class KBPostgresBackend(BackendProtocol):
continue
if not _is_under(path_key, normalized):
continue
if path_key in moved_removed or self._is_dir_suppressed(
path_key, deleted_dirs
):
continue
if any(e["path"] == path_key for e in entries):
continue
if not (

View file

@ -201,6 +201,12 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
)
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
# Pre-compute which folders have at least one descendant (folder or doc).
# A folder is "empty" iff no path in `all_paths` is strictly under it.
# Used to emit an explicit "(empty)" marker so the LLM doesn't have to
# infer emptiness from indentation alone.
non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)
lines: list[str] = []
for path in all_paths:
depth = (
@ -214,7 +220,10 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
)
if is_dir:
lines.append(f"{indent}{display}/")
if path != DOCUMENTS_ROOT and path not in non_empty_folders:
lines.append(f"{indent}{display}/ (empty)")
else:
lines.append(f"{indent}{display}/")
else:
lines.append(f"{indent}{display}")
if len(lines) >= self.max_entries:
@ -235,6 +244,35 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
return self._format_root_summary(folder_paths, doc_paths)
@staticmethod
def _compute_non_empty_folders(
folder_paths: list[str], doc_paths: list[str]
) -> set[str]:
"""Return the set of folder paths that contain at least one descendant.
A folder is "non-empty" if any document path or any other folder path
is strictly under it. Documents propagate emptiness up to every
ancestor folder, while a sub-folder only marks its direct ancestors
non-empty (so a chain of empty folders all read ``(empty)``).
"""
non_empty: set[str] = set()
folder_set = set(folder_paths)
for doc_path in doc_paths:
parent = doc_path.rsplit("/", 1)[0]
while parent and parent != DOCUMENTS_ROOT:
if parent in folder_set:
non_empty.add(parent)
parent = parent.rsplit("/", 1)[0]
for child in folder_paths:
parent = child.rsplit("/", 1)[0]
while parent and parent != DOCUMENTS_ROOT and parent in folder_set:
non_empty.add(parent)
parent = parent.rsplit("/", 1)[0]
return non_empty
def _format_root_summary(
self, folder_paths: list[str], doc_paths: list[str]
) -> str:

View file

@ -360,6 +360,74 @@ class LocalFolderBackend:
self.move, source_path, destination_path, overwrite
)
def delete_file(self, file_path: str) -> WriteResult:
"""Hard-delete a single file under root.
Refuses directories, root, and missing paths. Roughly mirrors POSIX
``rm path``; ``-r`` recursion and glob expansion are explicitly
out of scope.
"""
try:
path = self._resolve_virtual(file_path)
except ValueError:
return WriteResult(error=f"Error: Invalid path '{file_path}'")
with self._lock_for(file_path):
if not path.exists():
return WriteResult(error=f"Error: File '{file_path}' not found")
if path.is_dir():
return WriteResult(
error=(
f"Error: '{file_path}' is a directory. "
"Use rmdir for empty directories."
)
)
try:
os.unlink(path)
except OSError as exc:
return WriteResult(
error=f"Error: failed to delete '{file_path}': {exc}"
)
return WriteResult(path=file_path, files_update=None)
    async def adelete_file(self, file_path: str) -> WriteResult:
        # Async facade: run the blocking delete on a worker thread.
        return await asyncio.to_thread(self.delete_file, file_path)
def rmdir(self, dir_path: str) -> WriteResult:
"""Hard-delete an empty directory under root.
Refuses files, root, missing paths, and non-empty directories.
``os.rmdir`` is naturally empty-only; we pre-check so the error is
clearer for the agent.
"""
try:
path = self._resolve_virtual(dir_path)
except ValueError:
return WriteResult(error=f"Error: Invalid path '{dir_path}'")
with self._lock_for(dir_path):
if not path.exists():
return WriteResult(error=f"Error: Directory '{dir_path}' not found")
if not path.is_dir():
return WriteResult(error=f"Error: '{dir_path}' is not a directory")
try:
next(path.iterdir())
except StopIteration:
pass
else:
return WriteResult(
error=(
f"Error: directory '{dir_path}' is not empty. "
"Remove its contents first."
)
)
try:
os.rmdir(path)
except OSError as exc:
return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}")
return WriteResult(path=dir_path, files_update=None)
    async def armdir(self, dir_path: str) -> WriteResult:
        # Async facade: run the blocking rmdir on a worker thread.
        return await asyncio.to_thread(self.rmdir, dir_path)
def edit(
self,
file_path: str,

View file

@ -285,6 +285,34 @@ class MultiRootLocalFolderBackend:
overwrite,
)
def delete_file(self, file_path: str) -> WriteResult:
try:
mount, local_path = self._split_mount_path(file_path)
except ValueError as exc:
return WriteResult(error=f"Error: {exc}")
result = self._mount_to_backend[mount].delete_file(local_path)
if result.path:
result.path = self._prefix_mount_path(mount, result.path)
return result
    async def adelete_file(self, file_path: str) -> WriteResult:
        # Async facade: delegate to the sync mount router on a worker thread.
        return await asyncio.to_thread(self.delete_file, file_path)
def rmdir(self, dir_path: str) -> WriteResult:
try:
mount, local_path = self._split_mount_path(dir_path)
except ValueError as exc:
return WriteResult(error=f"Error: {exc}")
if local_path == "/":
return WriteResult(error=f"Error: cannot rmdir mount root '{dir_path}'")
result = self._mount_to_backend[mount].rmdir(local_path)
if result.path:
result.path = self._prefix_mount_path(mount, result.path)
return result
    async def armdir(self, dir_path: str) -> WriteResult:
        # Async facade: delegate to the sync mount router on a worker thread.
        return await asyncio.to_thread(self.rmdir, dir_path)
def edit(
self,
file_path: str,

View file

@ -181,9 +181,13 @@ def _initial_filesystem_state() -> dict[str, Any]:
return {
"cwd": "/documents",
"staged_dirs": [],
"staged_dir_tool_calls": {},
"pending_moves": [],
"pending_deletes": [],
"pending_dir_deletes": [],
"doc_id_by_path": {},
"dirty_paths": [],
"dirty_path_tool_calls": {},
"kb_priority": [],
"kb_matched_chunk_ids": {},
"kb_anon_doc": None,

View file

@ -90,6 +90,8 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = (
"write_file",
"move_file",
"mkdir",
"rm",
"rmdir",
"update_memory",
"update_memory_team",
"update_memory_private",

View file

@ -30,6 +30,35 @@ from langgraph.types import interrupt
logger = logging.getLogger(__name__)
# Tools that mirror the safety profile of ``write_file`` against the
# SurfSense KB: each call creates ONE artifact in the user's own workspace
# with no external visibility (drafts aren't sent; new files aren't shared
# unless the user shares them later). These are auto-approved by default
# so the agent can compose drafts and seed scratch files without a popup
# on every call.
#
# Members of this set still call ``request_approval`` exactly as before;
# the function returns immediately with ``decision_type="auto_approved"``
# and the original params untouched. This preserves the call-site shape
# (logging, metadata fetching, account fallbacks) so the only behavior
# change is "no interrupt fires".
#
# To re-enable prompting, the future per-search-space rules table
# (``agent_permission_rules``) takes precedence — see the ``# (future)``
# layer-3 comment in :mod:`app.agents.new_chat.chat_deepagent`.
# Kept alphabetically sorted so additions are easy to diff-review.
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
    (
        "create_confluence_page",
        "create_dropbox_file",
        "create_gmail_draft",
        "create_google_drive_file",
        "create_notion_page",
        "create_onedrive_file",
        "update_gmail_draft",
    )
)
@dataclass(frozen=True, slots=True)
class HITLResult:
"""Outcome of a human-in-the-loop approval request."""
@ -119,6 +148,19 @@ def request_approval(
logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
# Default policy: low-stakes creation tools (drafts + new-file
# creates) skip HITL because they're as recoverable as a local
# ``write_file`` against the SurfSense KB. The user can still
# delete the artifact in <30s if it's wrong.
logger.info(
"Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
tool_name,
)
return HITLResult(
rejected=False, decision_type="auto_approved", params=dict(params)
)
approval = interrupt(
{
"type": action_type,

View file

@ -689,6 +689,12 @@ class NewChatMessage(BaseModel, TimestampMixin):
index=True,
)
# Per-turn correlation id sourced from ``configurable.turn_id`` at
# streaming time (``f"{chat_id}:{ms}"``). Nullable because legacy rows
# predate the column. Used by C1's edit-from-arbitrary-position to map
# a message back to the LangGraph checkpoint that produced its turn.
turn_id = Column(String(64), nullable=True, index=True)
# Relationships
thread = relationship("NewChatThread", back_populates="messages")
author = relationship("User")
@ -2292,7 +2298,13 @@ class AgentActionLog(BaseModel):
nullable=False,
index=True,
)
# ``turn_id`` historically held the LangChain ``tool_call.id``. It has
# been renamed to ``tool_call_id`` (with a parallel column kept for one
# release for back-compat). The real chat-turn id lives in
# ``chat_turn_id`` and is sourced from ``configurable.turn_id``.
turn_id = Column(String(64), nullable=True, index=True)
tool_call_id = Column(String(64), nullable=True, index=True)
chat_turn_id = Column(String(64), nullable=True, index=True)
message_id = Column(String(128), nullable=True, index=True)
tool_name = Column(String(255), nullable=False, index=True)
args = Column(JSONB, nullable=True)
@ -2318,6 +2330,16 @@ class AgentActionLog(BaseModel):
__table_args__ = (
Index("ix_agent_action_log_thread_created", "thread_id", "created_at"),
# Partial unique index enforces "at most one revert per
# original action". Created in migration 137 with
# ``WHERE reverse_of IS NOT NULL`` so non-revert rows
# (the vast majority) are unaffected and NULLs don't collide.
Index(
"ux_agent_action_log_reverse_of",
"reverse_of",
unique=True,
postgresql_where=text("reverse_of IS NOT NULL"),
),
)
@ -2332,10 +2354,13 @@ class DocumentRevision(BaseModel):
__tablename__ = "document_revisions"
# ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
# hard-delete it describes — without that, ``rm`` would wipe the row
# we'd need to undo it. See migration ``134_relax_revision_fks``.
document_id = Column(
Integer,
ForeignKey("documents.id", ondelete="CASCADE"),
nullable=False,
ForeignKey("documents.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
search_space_id = Column(
@ -2370,10 +2395,13 @@ class FolderRevision(BaseModel):
__tablename__ = "folder_revisions"
# ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
# hard-delete it describes — without that, ``rmdir`` would wipe the
# row we'd need to undo it. See migration ``134_relax_revision_fks``.
folder_id = Column(
Integer,
ForeignKey("folders.id", ondelete="CASCADE"),
nullable=False,
ForeignKey("folders.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
search_space_id = Column(

View file

@ -65,6 +65,13 @@ class AgentActionRead(BaseModel):
reverse_of: int | None
reverted_by_action_id: int | None
is_revert_action: bool
# Correlation ids added in migration 135. ``tool_call_id`` is the
# LangChain tool-call id (joinable to ``data-action-log`` SSE events
# via ``langchainToolCallId``). ``chat_turn_id`` is the per-turn id
# from ``configurable.turn_id`` (used by the
# ``revert-turn/{chat_turn_id}`` endpoint).
tool_call_id: str | None = None
chat_turn_id: str | None = None
created_at: datetime
@ -172,6 +179,8 @@ async def list_thread_actions(
reverse_of=row.reverse_of,
reverted_by_action_id=revert_map.get(row.id),
is_revert_action=row.reverse_of is not None,
tool_call_id=row.tool_call_id,
chat_turn_id=row.chat_turn_id,
created_at=row.created_at,
)
for row in rows

View file

@ -11,14 +11,25 @@ flag flips. Once enabled, the route runs:
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
5. Idempotent on retries: if the same action is reverted twice the second
call returns 409 ``"already reverted"``.
This module also hosts the per-turn batch endpoint
``POST /api/threads/{thread_id}/revert-turn/{chat_turn_id}``. It
walks every reversible action emitted during a chat turn in reverse
``created_at`` order and reverts each independently. Partial success is the
common case — the response always contains a per-action result list and a
``status`` of ``"ok"`` or ``"partial"``; we never collapse the batch into a
whole-batch 4xx.
"""
from __future__ import annotations
import logging
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.feature_flags import get_flags
@ -97,6 +108,16 @@ async def revert_agent_action(
action=action,
requester_user_id=str(user.id) if user is not None else None,
)
except IntegrityError:
# Partial unique index ``ux_agent_action_log_reverse_of`` caught
# a concurrent revert. Translate to the existing 409 "already
# reverted" contract so racing clients see consistent
# behaviour with the pre-flight TOCTOU check above.
await session.rollback()
raise HTTPException(
status_code=409,
detail="This action has already been reverted.",
) from None
except Exception as err:
logger.exception("Revert dispatch raised for action_id=%s", action_id)
await session.rollback()
@ -105,7 +126,16 @@ async def revert_agent_action(
) from err
if outcome.status == "ok":
await session.commit()
try:
await session.commit()
except IntegrityError:
# Race lost on commit (constraint enforced at flush in some
# configs but at commit in others — defensive).
await session.rollback()
raise HTTPException(
status_code=409,
detail="This action has already been reverted.",
) from None
return {
"status": "ok",
"message": outcome.message,
@ -122,3 +152,357 @@ async def revert_agent_action(
raise HTTPException(status_code=501, detail=outcome.message)
# not_reversible
raise HTTPException(status_code=409, detail=outcome.message)
# ---------------------------------------------------------------------------
# Per-turn revert batch endpoint
# ---------------------------------------------------------------------------
# Per-action status vocabulary for the revert-turn batch response; the
# values mirror the counter fields on the batch response model one-to-one.
PerActionStatus = Literal[
    "reverted",
    "already_reverted",
    "not_reversible",
    "permission_denied",
    "failed",
    "skipped",
]
class RevertTurnActionResult(BaseModel):
    """Per-action outcome inside a ``revert-turn`` batch response."""

    # Id of the action log row this result describes.
    action_id: int
    # Name of the tool that produced the original action.
    tool_name: str
    # One of the ``PerActionStatus`` literals.
    status: PerActionStatus
    # Optional human-readable detail (e.g. why a row was not reversible).
    message: str | None = None
    # Id of the newly-created revert action, when one was created —
    # TODO confirm against the endpoint that populates it.
    new_action_id: int | None = None
    # Error detail — presumably set for ``status == "failed"`` rows; verify.
    error: str | None = None
class RevertTurnResponse(BaseModel):
    """Top-level response for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.

    ``status`` is ``"ok"`` only when every reversible row succeeded. Any
    ``failed`` / ``not_reversible`` / ``permission_denied`` entry downgrades
    it to ``"partial"``. Empty turns (no rows) return ``"ok"`` with an empty
    ``results`` list — callers should treat that as a no-op.

    Counter invariant:
        ``total == reverted + already_reverted + not_reversible
        + permission_denied + failed + skipped``

    Frontend toasts and the ``RevertTurnButton`` summary rely on this
    invariant to display "X of Y reverted, Z could not be undone" without
    silently dropping ``permission_denied`` or ``skipped`` rows.
    """

    # "ok" iff every reversible row succeeded; otherwise "partial".
    status: Literal["ok", "partial"]
    # Per-turn correlation id the batch was resolved against.
    chat_turn_id: str
    # Total rows considered; equals the sum of all counters below.
    total: int
    reverted: int
    already_reverted: int
    not_reversible: int
    permission_denied: int = 0
    failed: int = 0
    skipped: int = 0
    # One entry per considered action, in processing order.
    results: list[RevertTurnActionResult]
def _classify_outcome(outcome: RevertOutcome) -> PerActionStatus:
    """Collapse a service-layer ``RevertOutcome`` into a per-action status.

    ``ok`` and ``permission_denied`` map one-to-one. Every other service
    status (``not_found`` / ``tool_unavailable`` /
    ``reverse_not_implemented`` / ``not_reversible``) is surfaced to the
    caller as ``not_reversible`` — they share the same UX (this row
    cannot be undone) and only the ``message`` differs.
    """
    direct: dict[str, PerActionStatus] = {
        "ok": "reverted",
        "permission_denied": "permission_denied",
    }
    return direct.get(outcome.status, "not_reversible")
async def _was_already_reverted(session: AsyncSession, *, action_id: int) -> int | None:
    """Return the id of an existing successful revert row, if any.

    Single-action variant retained for the post-IntegrityError lookup
    path, where we already know we lost a race for one specific id.
    """
    lookup = await session.execute(
        select(AgentActionLog.id).where(AgentActionLog.reverse_of == action_id)
    )
    return lookup.scalars().first()
async def _was_already_reverted_batch(
    session: AsyncSession, *, action_ids: list[int]
) -> dict[int, int]:
    """Batch idempotency probe for the revert-turn loop.

    Replaces N individual ``SELECT id WHERE reverse_of = :id`` queries
    (one per row in the turn) with a single ``SELECT id, reverse_of
    WHERE reverse_of IN (:ids)``. The route still iterates rows in
    reverse-chronological order, but the membership check is O(1) per
    iteration after this query. For a turn with 30 actions that's 30
    fewer round-trips through asyncpg and a smaller transaction
    footprint.

    Returns a ``{original_action_id -> revert_action_id}`` map. Missing
    keys mean "not yet reverted" — callers should treat them as
    eligible for revert.
    """
    if not action_ids:
        return {}
    lookup = await session.execute(
        select(AgentActionLog.id, AgentActionLog.reverse_of).where(
            AgentActionLog.reverse_of.in_(action_ids)
        )
    )
    mapping: dict[int, int] = {}
    # Row tuples are (revert_id, original_id) per the SELECT column order.
    for revert_id, original_id in lookup.all():
        if original_id is not None:
            mapping[original_id] = revert_id
    return mapping
@router.post(
    "/threads/{thread_id}/revert-turn/{chat_turn_id}",
    response_model=RevertTurnResponse,
)
async def revert_agent_turn(
    thread_id: int,
    chat_turn_id: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
) -> RevertTurnResponse:
    """Revert every reversible action emitted during ``chat_turn_id``.

    Walks ``AgentActionLog`` rows for the turn in reverse ``created_at``
    order so dependencies (e.g. ``mkdir`` -> ``write_file`` inside the new
    folder) unwind in the right sequence. Each action is reverted in its
    own SAVEPOINT so a single failure does not poison the batch.

    Partial success is intentional and returned with HTTP 200. Callers
    must inspect ``results[*].status`` to find rows that need attention.

    Args:
        thread_id: Thread whose action log is walked.
        chat_turn_id: Per-turn correlation id selecting the rows to revert.
        session: Injected request-scoped async DB session.
        user: Authenticated requester; drives the per-row ``can_revert``
            permission check.

    Raises:
        HTTPException: 503 when the revert route is feature-flagged off,
            404 when the thread does not exist, 500 when the final commit
            fails.
    """
    flags = get_flags()
    if flags.disable_new_agent_stack or not flags.enable_revert_route:
        raise HTTPException(
            status_code=503,
            detail=(
                "Revert is not available on this deployment yet. The route "
                "ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to "
                "enable it."
            ),
        )
    thread = await load_thread(session, thread_id=thread_id)
    if thread is None:
        raise HTTPException(status_code=404, detail="Thread not found.")
    # Reverse-chronological so the latest mutation in the turn unwinds
    # first. ``id.desc()`` is the deterministic tiebreaker for actions
    # written in the same millisecond.
    rows_stmt = (
        select(AgentActionLog)
        .where(
            AgentActionLog.thread_id == thread_id,
            AgentActionLog.chat_turn_id == chat_turn_id,
        )
        .order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc())
    )
    rows = (await session.execute(rows_stmt)).scalars().all()
    requester_user_id = str(user.id) if user is not None else None
    results: list[RevertTurnActionResult] = []
    # Counters MUST be exhaustive so the response invariant
    # ``total == sum(counters)`` always holds. Frontend toasts and
    # ``RevertTurnButton`` rely on this for "X of Y reverted" math.
    counts: dict[str, int] = {
        "reverted": 0,
        "already_reverted": 0,
        "not_reversible": 0,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }
    # Single batched idempotency probe replaces the previous per-row
    # SELECT. ``rows`` are filtered in the loop so we pre-collect only
    # the original-action ids (skip rows that are themselves
    # reverts).
    eligible_ids = [r.id for r in rows if r.reverse_of is None]
    already_reverted_map = await _was_already_reverted_batch(
        session, action_ids=eligible_ids
    )
    # Every iteration increments exactly one counter and appends exactly
    # one result, which is what keeps the invariant above true.
    for action in rows:
        # Skip rows that ARE reverts of an earlier action — reverting a
        # revert is meaningless inside a batch (the user wants to wipe
        # the original effects, not chase tail).
        if action.reverse_of is not None:
            counts["skipped"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="skipped",
                    message="Row is itself a revert action; skipped.",
                )
            )
            continue
        # Idempotency: surface "already_reverted" instead of failing.
        existing_revert_id = already_reverted_map.get(action.id)
        if existing_revert_id is not None:
            counts["already_reverted"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="already_reverted",
                    new_action_id=existing_revert_id,
                )
            )
            continue
        # Pre-flight permission check avoids opening a SAVEPOINT for
        # rows the requester may not touch.
        if not can_revert(
            requester_user_id=requester_user_id,
            action=action,
            is_admin=False,
        ):
            counts["permission_denied"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="permission_denied",
                    message="You are not allowed to revert this action.",
                )
            )
            continue
        # Per-row SAVEPOINT so one failed revert never poisons later
        # successful ones.
        try:
            async with session.begin_nested():
                outcome = await revert_action(
                    session,
                    action=action,
                    requester_user_id=requester_user_id,
                )
                # Raising inside ``begin_nested`` rolls the savepoint
                # back, discarding any partial writes from the failed
                # revert attempt.
                if outcome.status != "ok":
                    raise _OutcomeRollbackError(outcome)
        except _OutcomeRollbackError as rollback:
            outcome = rollback.outcome
            classified = _classify_outcome(outcome)
            if classified == "permission_denied":
                counts["permission_denied"] += 1
            else:
                counts["not_reversible"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status=classified,
                    message=outcome.message,
                )
            )
            continue
        except IntegrityError:
            # Partial unique index caught a concurrent revert that won
            # the race against our pre-flight ``_was_already_reverted``
            # SELECT. Look up the winner so
            # we can surface its ``new_action_id`` to the client.
            # NOTE(review): ``existing_revert_id`` may still be None if
            # the winning transaction is not yet visible to ours —
            # confirm isolation level if this matters to clients.
            existing_revert_id = await _was_already_reverted(
                session, action_id=action.id
            )
            counts["already_reverted"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="already_reverted",
                    new_action_id=existing_revert_id,
                )
            )
            continue
        except Exception as err:  # pragma: no cover — defensive, logged
            logger.exception(
                "Unexpected revert failure inside batch for action_id=%s",
                action.id,
            )
            counts["failed"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="failed",
                    error=str(err) or err.__class__.__name__,
                )
            )
            continue
        # Happy path: the SAVEPOINT committed (released) cleanly.
        counts["reverted"] += 1
        results.append(
            RevertTurnActionResult(
                action_id=action.id,
                tool_name=action.tool_name,
                status="reverted",
                message=outcome.message,
                new_action_id=outcome.new_action_id,
            )
        )
    # Single commit at the end — successful SAVEPOINTs above already
    # released; failed ones rolled back to their savepoint. No row leaks
    # across the boundary.
    try:
        await session.commit()
    except Exception as err:  # pragma: no cover — defensive
        logger.exception(
            "Final commit for revert-turn failed (thread=%s turn=%s)",
            thread_id,
            chat_turn_id,
        )
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail="Internal error while finalising revert-turn batch.",
        ) from err
    # ``skipped`` and ``already_reverted`` do NOT downgrade the status —
    # only rows the user would want to act on do.
    has_partial = (
        counts["failed"] > 0
        or counts["not_reversible"] > 0
        or counts["permission_denied"] > 0
    )
    overall_status: Literal["ok", "partial"] = "partial" if has_partial else "ok"
    return RevertTurnResponse(
        status=overall_status,
        chat_turn_id=chat_turn_id,
        total=len(rows),
        reverted=counts["reverted"],
        already_reverted=counts["already_reverted"],
        not_reversible=counts["not_reversible"],
        permission_denied=counts["permission_denied"],
        failed=counts["failed"],
        skipped=counts["skipped"],
        results=results,
    )
class _OutcomeRollbackError(Exception):
    """Sentinel raised inside a SAVEPOINT to discard a non-OK outcome.

    ``revert_action`` writes a new ``agent_action_log`` row only on the
    happy path, but some failure paths mutate the
    ``DocumentRevision``/``Document`` tables before concluding the
    action is not reversible. Each call is wrapped in ``begin_nested``;
    raising this from the failure branch guarantees those partial
    writes are rolled back with the savepoint.
    """

    def __init__(self, outcome: RevertOutcome) -> None:
        super().__init__(outcome.message)
        # Carried so the handler can classify and report the outcome.
        self.outcome = outcome


__all__ = ["router"]

View file

@ -11,6 +11,7 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
"""
import asyncio
import json
import logging
from datetime import UTC, datetime
@ -136,6 +137,260 @@ def _resolve_filesystem_selection(
)
def _find_pre_turn_checkpoint_id(
checkpoint_tuples: list,
*,
turn_id: str,
) -> str | None:
"""Locate the LangGraph checkpoint immediately before ``turn_id`` started.
``checkpoint_tuples`` arrives newest-first from
``checkpointer.alist(config)``. We walk OLDEST-first (``reversed``)
and remember the most recent checkpoint that does NOT belong to the
edited turn. As soon as we cross into the edited turn (a checkpoint
whose ``turn_id`` matches), we return the previously-tracked
checkpoint that's the state immediately before ``turn_id`` began.
The naive "newest-first, return first non-matching" approach is
INCORRECT when later turns exist after ``turn_id``: their
checkpoints also satisfy ``cp_turn_id != turn_id`` and would be
returned before the real pre-turn boundary is reached.
Reads from ``cp_tuple.metadata`` (the durable surface promoted from
``configurable`` at write time) rather than ``config["configurable"]``
so the lookup is portable across checkpointer implementations.
Returns ``None`` when no eligible pre-turn checkpoint exists (e.g.
the edited turn is the very first turn of the thread). Callers fall
back to the oldest available checkpoint in that case.
"""
last_pre_turn_target: str | None = None
for cp_tuple in reversed(checkpoint_tuples): # oldest -> newest
metadata = getattr(cp_tuple, "metadata", None) or {}
cp_turn_id = metadata.get("turn_id") if isinstance(metadata, dict) else None
if cp_turn_id == turn_id:
# Crossed into the edited turn; the previous tracked
# checkpoint is the rewind target. May be ``None`` if we hit
# the edited turn on the very first iteration.
return last_pre_turn_target
try:
last_pre_turn_target = cp_tuple.config["configurable"]["checkpoint_id"]
except (KeyError, TypeError):
continue
return last_pre_turn_target
async def _revert_turns_for_regenerate(
    *,
    thread_id: int,
    chat_turn_ids: list[str],
    requester_user_id: str,
) -> dict:
    """Best-effort revert pass for every ``chat_turn_id`` in ``chat_turn_ids``.

    Runs BEFORE the regenerate stream so the frontend can surface
    partial-rollback feedback alongside the new assistant turn. Each
    turn's actions are reverted in their own SAVEPOINTs (handled
    inside :mod:`app.routes.agent_revert_route`'s helpers) so a single
    failure never poisons the batch.

    Sequencing inside the request: revert THEN regenerate. The
    operation is NOT atomic and partial state IS surfaced — see the
    plan's "Sequencing inside the request" note.

    Args:
        thread_id: Thread whose action-log rows are walked.
        chat_turn_ids: Turn ids whose actions should be unwound;
            processed in the given order.
        requester_user_id: Stringified requesting user id for the
            per-row ``can_revert`` permission check.

    Returns:
        A JSON-ready dict with overall ``status`` (``"ok"``/``"partial"``),
        the echoed ``chat_turn_ids``, exhaustive counters, and per-action
        ``results`` for the ``data-revert-results`` SSE event.
    """
    # Local imports avoid a circular dependency at module-load time and
    # keep the route-private helpers out of this module's namespace.
    from app.routes.agent_revert_route import (
        RevertTurnActionResult,
        _classify_outcome,
        _OutcomeRollbackError,
        _was_already_reverted,
        _was_already_reverted_batch,
    )
    from app.services.revert_service import (
        can_revert,
        revert_action,
    )
    aggregated_results: list[dict] = []
    # Exhaustive counters keep the response invariant
    # ``total == sum(counters)`` true for ``data-revert-results``.
    counts: dict[str, int] = {
        "reverted": 0,
        "already_reverted": 0,
        "not_reversible": 0,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }
    # Local import keeps the route module's existing imports tidy and
    # avoids a circular dependency at module-load time.
    from app.db import AgentActionLog as _AgentActionLog
    async with shielded_async_session() as session:
        for chat_turn_id in chat_turn_ids:
            # Reverse-chronological so the latest mutation in a turn
            # unwinds first; ``id.desc()`` breaks same-millisecond ties.
            rows_stmt = (
                select(_AgentActionLog)
                .where(
                    _AgentActionLog.thread_id == thread_id,
                    _AgentActionLog.chat_turn_id == chat_turn_id,
                )
                .order_by(
                    _AgentActionLog.created_at.desc(),
                    _AgentActionLog.id.desc(),
                )
            )
            rows = (await session.execute(rows_stmt)).scalars().all()
            # Batch idempotency probe across the turn (single SELECT
            # instead of one per row).
            eligible_ids = [r.id for r in rows if r.reverse_of is None]
            already_reverted_map = await _was_already_reverted_batch(
                session, action_ids=eligible_ids
            )
            for action in rows:
                # Rows that ARE reverts of earlier actions are skipped —
                # the user wants the original effects wiped, not a
                # revert-of-a-revert.
                if action.reverse_of is not None:
                    counts["skipped"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="skipped",
                            message="Row is itself a revert action; skipped.",
                        ).model_dump()
                    )
                    continue
                # Idempotency: report "already_reverted" rather than fail.
                existing_revert_id = already_reverted_map.get(action.id)
                if existing_revert_id is not None:
                    counts["already_reverted"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="already_reverted",
                            new_action_id=existing_revert_id,
                        ).model_dump()
                    )
                    continue
                if not can_revert(
                    requester_user_id=requester_user_id,
                    action=action,
                    is_admin=False,
                ):
                    counts["permission_denied"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="permission_denied",
                            message="You are not allowed to revert this action.",
                        ).model_dump()
                    )
                    continue
                # Per-row SAVEPOINT: a failed revert rolls back to the
                # savepoint without poisoning earlier successes.
                try:
                    async with session.begin_nested():
                        outcome = await revert_action(
                            session,
                            action=action,
                            requester_user_id=requester_user_id,
                        )
                        if outcome.status != "ok":
                            raise _OutcomeRollbackError(outcome)
                except _OutcomeRollbackError as rollback:
                    outcome = rollback.outcome
                    classified = _classify_outcome(outcome)
                    if classified == "permission_denied":
                        counts["permission_denied"] += 1
                    else:
                        counts["not_reversible"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status=classified,
                            message=outcome.message,
                        ).model_dump()
                    )
                    continue
                except IntegrityError:
                    # Concurrent revert won the race against the
                    # pre-flight ``_was_already_reverted`` SELECT.
                    # Surface the winning revert id so the client can
                    # treat this as a successful idempotent op.
                    existing_revert_id = await _was_already_reverted(
                        session, action_id=action.id
                    )
                    counts["already_reverted"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="already_reverted",
                            new_action_id=existing_revert_id,
                        ).model_dump()
                    )
                    continue
                except Exception as err:  # pragma: no cover — defensive
                    _logger.exception(
                        "Unexpected revert failure during regenerate batch "
                        "for action_id=%s",
                        action.id,
                    )
                    counts["failed"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="failed",
                            error=str(err) or err.__class__.__name__,
                        ).model_dump()
                    )
                    continue
                counts["reverted"] += 1
                aggregated_results.append(
                    RevertTurnActionResult(
                        action_id=action.id,
                        tool_name=action.tool_name,
                        status="reverted",
                        message=outcome.message,
                        new_action_id=outcome.new_action_id,
                    ).model_dump()
                )
        try:
            await session.commit()
        except Exception:
            _logger.exception(
                "[regenerate-revert] Final commit failed; rolling back batch."
            )
            await session.rollback()
            # NOTE(review): counters/results still reflect the
            # pre-rollback tallies here — best-effort contract; the SSE
            # payload is advisory, not transactional truth.
    has_partial = (
        counts["failed"] > 0
        or counts["not_reversible"] > 0
        or counts["permission_denied"] > 0
    )
    return {
        "status": "partial" if has_partial else "ok",
        "chat_turn_ids": chat_turn_ids,
        "total": len(aggregated_results),
        "reverted": counts["reverted"],
        "already_reverted": counts["already_reverted"],
        "not_reversible": counts["not_reversible"],
        "permission_denied": counts["permission_denied"],
        "failed": counts["failed"],
        "skipped": counts["skipped"],
        "results": aggregated_results,
    }
def _try_delete_sandbox(thread_id: int) -> None:
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
from app.agents.new_chat.sandbox import (
@ -574,6 +829,7 @@ async def get_thread_messages(
token_usage=TokenUsageSummary.model_validate(msg.token_usage)
if msg.token_usage
else None,
turn_id=msg.turn_id,
)
for msg in db_messages
]
@ -1006,12 +1262,24 @@ async def append_message(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
# Create message
# Create message. ``turn_id`` is the per-turn correlation id from
# ``configurable.turn_id`` (added in migration 136) — when the
# client streams it back to ``appendMessage``, we persist it so
# C1's edit-from-arbitrary-position can later map this message
# back to the LangGraph checkpoint that produced its turn.
raw_turn_id = raw_body.get("turn_id")
turn_id_value = (
str(raw_turn_id).strip()
if isinstance(raw_turn_id, str) and raw_turn_id.strip()
else None
)
db_message = NewChatMessage(
thread_id=thread_id,
role=message_role,
content=content,
author_id=user.id,
turn_id=turn_id_value,
)
session.add(db_message)
@ -1050,6 +1318,7 @@ async def append_message(
created_at=db_message.created_at,
author_id=db_message.author_id,
token_usage=None,
turn_id=db_message.turn_id,
)
except HTTPException:
@ -1373,43 +1642,123 @@ async def regenerate_response(
user_query_to_use = request.user_query
regenerate_image_urls: list[str] = []
# Look through checkpoints to find the right one
# We want to find the checkpoint just before the last HumanMessage
for i, cp_tuple in enumerate(checkpoint_tuples):
# Access the checkpoint's channel_values which contains "messages"
checkpoint_data = cp_tuple.checkpoint
channel_values = checkpoint_data.get("channel_values", {})
state_messages = channel_values.get("messages", [])
# ---------------------------------------------------------------
# Edit-from-arbitrary-position. When the client passes
# ``from_message_id`` we look up its persisted ``turn_id`` (added
# in migration 136) and pick the checkpoint immediately before
# that turn started.
#
# Legacy graceful-degradation contract:
# * Rows persisted BEFORE migration 136 have ``turn_id IS NULL``.
# Returning 400 in that case is the wrong UX — the user is
# editing an old message in an existing thread and just wants
# it to work. We instead skip the checkpoint rewind (the
# stream falls back to the latest state) and skip the revert
# pass (no chat_turn_id available to walk). Deletion still
# uses ``created_at``, so the messages-after-cursor slice is
# correct on both legacy and post-136 rows.
# ---------------------------------------------------------------
from_message_turn_id: str | None = None
from_message_created_at: datetime | None = None
legacy_from_message: bool = False
if request.from_message_id is not None:
from_msg_row = await session.execute(
select(NewChatMessage).filter(
NewChatMessage.id == request.from_message_id,
NewChatMessage.thread_id == thread_id,
)
)
from_msg = from_msg_row.scalars().first()
if from_msg is None:
raise HTTPException(
status_code=404,
detail="from_message_id not found in this thread.",
)
from_message_created_at = from_msg.created_at
if not from_msg.turn_id:
# Legacy row — surface the degradation in logs but let
# the request proceed with the slice-based delete and a
# cold-start checkpoint.
legacy_from_message = True
_logger.warning(
"[regenerate] from_message_id=%s on thread=%s has no "
"turn_id (legacy row pre-migration-136). Falling back "
"to slice-based delete without checkpoint rewind. "
"revert_actions=%s will be ignored.",
request.from_message_id,
thread_id,
request.revert_actions,
)
else:
from_message_turn_id = from_msg.turn_id
if state_messages:
last_msg = state_messages[-1]
# Find a checkpoint where the last message is NOT a HumanMessage
# This means we're at a state before the user's last message
if not isinstance(last_msg, HumanMessage):
# If no new user_query provided (reload), extract from a later checkpoint
if user_query_to_use is None and i > 0:
# Get the user query from a more recent checkpoint
for prev_cp_tuple in checkpoint_tuples[:i]:
prev_checkpoint_data = prev_cp_tuple.checkpoint
prev_channel_values = prev_checkpoint_data.get(
"channel_values", {}
)
prev_messages = prev_channel_values.get("messages", [])
for msg in reversed(prev_messages):
if isinstance(msg, HumanMessage):
q, imgs = split_langchain_human_content(msg.content)
user_query_to_use = q
regenerate_image_urls = imgs
break
if user_query_to_use is not None and (
str(user_query_to_use).strip() or regenerate_image_urls
):
break
target_checkpoint_id = cp_tuple.config["configurable"][
# Walk oldest-to-newest and pick the LAST checkpoint whose
# ``turn_id`` differs from the edited turn — that's the state
# immediately before this turn started running. We read from
# ``metadata`` (the durable surface) rather than
# ``config["configurable"]`` so the lookup works across
# checkpointer implementations.
target_checkpoint_id = _find_pre_turn_checkpoint_id(
checkpoint_tuples,
turn_id=from_message_turn_id,
)
if target_checkpoint_id is None and len(checkpoint_tuples) > 0:
# Fall back to the oldest checkpoint — better than
# 400ing when the agent didn't checkpoint pre-turn
# (e.g. very first turn of the thread).
target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
"checkpoint_id"
]
break
# Look through checkpoints to find the right one
# We want to find the checkpoint just before the last HumanMessage.
# We enter this branch when:
# * the client did NOT pin ``from_message_id`` (legacy reload/edit), OR
# * the client pinned ``from_message_id`` but the row is a
# legacy pre-migration-136 row with no ``turn_id`` (we
# downgraded to the same heuristic as a regular reload).
# We DO skip it when a real turn_id pinned ``target_checkpoint_id``
# — that's the C1 happy path and the heuristic below would just
# re-derive a worse target.
if request.from_message_id is None or legacy_from_message:
for i, cp_tuple in enumerate(checkpoint_tuples):
# Access the checkpoint's channel_values which contains "messages"
checkpoint_data = cp_tuple.checkpoint
channel_values = checkpoint_data.get("channel_values", {})
state_messages = channel_values.get("messages", [])
if state_messages:
last_msg = state_messages[-1]
# Find a checkpoint where the last message is NOT a HumanMessage
# This means we're at a state before the user's last message
if not isinstance(last_msg, HumanMessage):
# If no new user_query provided (reload), extract from a later checkpoint
if user_query_to_use is None and i > 0:
# Get the user query from a more recent checkpoint
for prev_cp_tuple in checkpoint_tuples[:i]:
prev_checkpoint_data = prev_cp_tuple.checkpoint
prev_channel_values = prev_checkpoint_data.get(
"channel_values", {}
)
prev_messages = prev_channel_values.get("messages", [])
for msg in reversed(prev_messages):
if isinstance(msg, HumanMessage):
q, imgs = split_langchain_human_content(
msg.content
)
user_query_to_use = q
regenerate_image_urls = imgs
break
if user_query_to_use is not None and (
str(user_query_to_use).strip()
or regenerate_image_urls
):
break
target_checkpoint_id = cp_tuple.config["configurable"][
"checkpoint_id"
]
break
# If we couldn't find a good checkpoint, try alternative approaches
if target_checkpoint_id is None and checkpoint_tuples:
@ -1472,18 +1821,51 @@ async def regenerate_response(
detail="Could not determine user query for regeneration. Please provide a user_query.",
)
# Get the last two messages to delete AFTER streaming succeeds
# This prevents data loss if streaming fails
last_messages_result = await session.execute(
select(NewChatMessage)
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at.desc())
.limit(2)
)
# Get the messages to delete AFTER streaming succeeds.
# This prevents data loss if streaming fails.
#
# When ``from_message_id`` is set we slice from that message
# forward (using ``created_at`` so we also catch any tool/system
# messages persisted into the same turn). Otherwise
# we keep the legacy "last 2 messages" rewind.
if request.from_message_id is not None and from_message_created_at is not None:
last_messages_result = await session.execute(
select(NewChatMessage)
.filter(
NewChatMessage.thread_id == thread_id,
NewChatMessage.created_at >= from_message_created_at,
)
.order_by(NewChatMessage.created_at.desc())
)
else:
last_messages_result = await session.execute(
select(NewChatMessage)
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at.desc())
.limit(2)
)
messages_to_delete = list(last_messages_result.scalars().all())
message_ids_to_delete = [msg.id for msg in messages_to_delete]
# When revert_actions is requested, collect the set of
# ``chat_turn_id``s present in the slice we're about to delete.
# Each one will be reverted (best-effort) BEFORE the regenerate
# stream begins. Legacy rows have ``turn_id=None`` and silently
# contribute nothing — we already logged the degradation above.
revert_turn_ids: list[str] = []
if (
request.revert_actions
and request.from_message_id is not None
and not legacy_from_message
):
seen_turns: set[str] = set()
for msg in messages_to_delete:
tid = msg.turn_id
if tid and tid not in seen_turns:
seen_turns.add(tid)
revert_turn_ids.append(tid)
# Get search space for LLM config
search_space_result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
@ -1507,6 +1889,24 @@ async def regenerate_response(
# This prevents data loss if streaming fails (network error, LLM error, etc.)
async def stream_with_cleanup():
streaming_completed = False
# Best-effort revert pass BEFORE the regenerate stream begins.
# Each turn is reverted independently (per-row SAVEPOINTs
# inside the route helper) and the per-action results are surfaced
# on a single ``data-revert-results`` SSE event so the frontend
# can render any failed rows alongside the new turn. Failures here
# do NOT abort the regeneration — partial rollback is documented
# behaviour.
if revert_turn_ids:
revert_results = await _revert_turns_for_regenerate(
thread_id=thread_id,
chat_turn_ids=revert_turn_ids,
requester_user_id=str(user.id),
)
envelope = {
"type": "data-revert-results",
"data": revert_results,
}
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
try:
async for chunk in stream_new_chat(
user_query=str(user_query_to_use),

View file

@ -51,6 +51,11 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
author_display_name: str | None = None
author_avatar_url: str | None = None
token_usage: TokenUsageSummary | None = None
# Per-turn correlation id (``f"{chat_id}:{ms}"``) from
# ``configurable.turn_id`` at streaming time. Nullable because
# legacy rows predate the column; clients should treat NULL as
# "edit-from-this-message is unavailable".
turn_id: str | None = None
model_config = ConfigDict(from_attributes=True)
@ -241,6 +246,15 @@ class RegenerateRequest(BaseModel):
For edit, optional user_images (when not None) replaces image URLs resolved from
checkpoint/DB so the client can send the full user turn (text and/or images).
Edit-from-arbitrary-position. When ``from_message_id`` is provided
the route slices conversation history starting at that message (instead of
the legacy "last 2 messages" rewind), rewinds the LangGraph checkpoint by
matching ``configurable.turn_id`` stored on the message (added in migration 136), and
optionally reverts every reversible action emitted in turns at or after
``from_message_id``. The revert step is best-effort and runs BEFORE the
regenerate stream — partial failures are surfaced via SSE
``data-revert-results`` and do not abort the regeneration.
"""
search_space_id: int
@ -257,6 +271,28 @@ class RegenerateRequest(BaseModel):
default=None,
description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB",
)
from_message_id: int | None = Field(
default=None,
description=(
"Message id to rewind to. When set, history is sliced "
"from this message forward and the LangGraph checkpoint is "
"rewound to the state immediately preceding this turn. Legacy "
"rows that predate migration 136 have ``turn_id=None`` and "
"still process — the route logs a warning, skips the "
"checkpoint rewind, and ignores ``revert_actions`` (no "
"chat_turn_id available to walk)."
),
)
revert_actions: bool = Field(
default=False,
description=(
"When true, every reversible action emitted at or "
"after ``from_message_id`` is reverted before the regenerate "
"stream begins. Per-action results are surfaced via the "
"``data-revert-results`` SSE event. Partial failures DO NOT "
"abort the regeneration."
),
)
@model_validator(mode="after")
def _validate_regenerate_user_images(self) -> Self:
@ -264,6 +300,14 @@ class RegenerateRequest(BaseModel):
raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed")
return self
@model_validator(mode="after")
def _validate_revert_actions_requires_from_message(self) -> Self:
if self.revert_actions and self.from_message_id is None:
raise ValueError(
"revert_actions requires from_message_id; specify which message to rewind to"
)
return self
# =============================================================================
# Agent Tools Schemas

View file

@ -584,13 +584,33 @@ class VercelStreamingService:
# Tool Parts
# =========================================================================
def format_tool_input_start(self, tool_call_id: str, tool_name: str) -> str:
def format_tool_input_start(
self,
tool_call_id: str,
tool_name: str,
*,
langchain_tool_call_id: str | None = None,
) -> str:
"""
Format the start of tool input streaming.
Args:
tool_call_id: The unique tool call identifier
tool_name: The name of the tool being called
tool_call_id: The unique tool call identifier. May be EITHER the
synthetic ``call_<run_id>`` id derived from LangGraph
``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
OFF, or the unmatched-fallback path under parity_v2) OR
the authoritative LangChain ``tool_call.id`` (parity_v2
path: when the provider streams ``tool_call_chunks`` we
register the ``index`` and reuse the lc-id as the card
id so live ``tool-input-delta`` events can be routed
without a downstream join). Either way, the same id is
preserved across ``tool-input-start`` / ``-delta`` /
``-available`` / ``tool-output-available`` for one call.
tool_name: The name of the tool being called.
langchain_tool_call_id: Optional authoritative LangChain
``tool_call.id``. When set, surfaces as
``langchainToolCallId`` so the frontend can join this card
to the action-log row written by ``ActionLogMiddleware``.
Returns:
str: SSE formatted tool input start part
@ -598,13 +618,14 @@ class VercelStreamingService:
Example output:
data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"}
"""
return self._format_sse(
{
"type": "tool-input-start",
"toolCallId": tool_call_id,
"toolName": tool_name,
}
)
payload: dict[str, Any] = {
"type": "tool-input-start",
"toolCallId": tool_call_id,
"toolName": tool_name,
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
return self._format_sse(payload)
def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:
"""
@ -629,7 +650,12 @@ class VercelStreamingService:
)
def format_tool_input_available(
self, tool_call_id: str, tool_name: str, input_data: dict[str, Any]
self,
tool_call_id: str,
tool_name: str,
input_data: dict[str, Any],
*,
langchain_tool_call_id: str | None = None,
) -> str:
"""
Format the completion of tool input.
@ -638,6 +664,8 @@ class VercelStreamingService:
tool_call_id: The tool call identifier
tool_name: The name of the tool
input_data: The complete tool input parameters
langchain_tool_call_id: Optional authoritative LangChain
``tool_call.id`` (see ``format_tool_input_start``).
Returns:
str: SSE formatted tool input available part
@ -645,22 +673,34 @@ class VercelStreamingService:
Example output:
data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}}
"""
return self._format_sse(
{
"type": "tool-input-available",
"toolCallId": tool_call_id,
"toolName": tool_name,
"input": input_data,
}
)
payload: dict[str, Any] = {
"type": "tool-input-available",
"toolCallId": tool_call_id,
"toolName": tool_name,
"input": input_data,
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
return self._format_sse(payload)
def format_tool_output_available(self, tool_call_id: str, output: Any) -> str:
def format_tool_output_available(
self,
tool_call_id: str,
output: Any,
*,
langchain_tool_call_id: str | None = None,
) -> str:
"""
Format tool execution output.
Args:
tool_call_id: The tool call identifier
output: The tool execution result
langchain_tool_call_id: Optional authoritative LangChain
``tool_call.id`` extracted from ``ToolMessage.tool_call_id``.
When set, the frontend can backfill any card whose
``langchainToolCallId`` was not yet known at
``tool-input-start`` time.
Returns:
str: SSE formatted tool output available part
@ -668,13 +708,14 @@ class VercelStreamingService:
Example output:
data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}}
"""
return self._format_sse(
{
"type": "tool-output-available",
"toolCallId": tool_call_id,
"output": output,
}
)
payload: dict[str, Any] = {
"type": "tool-output-available",
"toolCallId": tool_call_id,
"output": output,
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
return self._format_sse(payload)
# =========================================================================
# Step Parts

View file

@ -8,7 +8,9 @@ Operation outcomes mirror the plan:
* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
:class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
written before the original mutation.
written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh
row from the snapshot; ``write_file`` create / ``mkdir`` DELETE the row
that was created; everything else is an in-place restore.
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
the inverse tool through the agent's normal permission stack (NOT
bypassed). Out of scope for this PR returns ``REVERSE_NOT_IMPLEMENTED``.
@ -18,6 +20,11 @@ Operation outcomes mirror the plan:
A successful revert appends a NEW row to ``agent_action_log`` with
``reverse_of=<original_action_id>`` and the requesting user's
``user_id``, preserving an auditable chain.
Dispatch must be exact-match (``tool_name == name``), NOT prefix matching.
``"rmdir".startswith("rm")`` would otherwise mis-route directory revert
to the document branch (and ``delete_note`` vs ``delete_folder`` is the
same trap waiting to happen).
"""
from __future__ import annotations
@ -25,17 +32,31 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Literal
from typing import Any, Literal
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.path_resolver import (
DOCUMENTS_ROOT,
safe_filename,
safe_folder_segment,
)
from app.db import (
AgentActionLog,
Chunk,
Document,
DocumentRevision,
DocumentType,
Folder,
FolderRevision,
NewChatThread,
)
from app.utils.document_converters import (
embed_texts,
generate_content_hash,
generate_unique_identifier_hash,
)
logger = logging.getLogger(__name__)
@ -110,14 +131,244 @@ def can_revert(
# ---------------------------------------------------------------------------
# Revert paths
# Helper: reconstruct virtual path from a snapshot
# ---------------------------------------------------------------------------
async def _virtual_path_from_snapshot(
    session: AsyncSession,
    revision: DocumentRevision,
) -> str | None:
    """Reconstruct the virtual_path the document was at before mutation.

    Preference order:

    1. ``metadata_before["virtual_path"]`` written by every snapshot
       helper since this PR.
    2. Compose ``"<folder_path>/<title_before>"`` from
       ``folder_id_before`` + ``title_before``. Walks the folder chain via
       ``parent_id``.

    Args:
        session: Active async DB session used to walk the folder chain.
        revision: Snapshot row whose pre-mutation location is wanted.

    Returns:
        The reconstructed path rooted at ``DOCUMENTS_ROOT``, or ``None``
        when neither the metadata hint nor the folder chain is resolvable
        (missing folder row, empty title, …).
    """
    metadata = revision.metadata_before or {}
    candidate = metadata.get("virtual_path") if isinstance(metadata, dict) else None
    if isinstance(candidate, str) and candidate.startswith(DOCUMENTS_ROOT):
        return candidate
    title = revision.title_before
    if not isinstance(title, str) or not title:
        return None
    parts: list[str] = []
    cursor: int | None = revision.folder_id_before
    # ``visited`` guards against a corrupt parent_id cycle looping forever.
    visited: set[int] = set()
    while cursor is not None and cursor not in visited:
        visited.add(cursor)
        folder = await session.get(Folder, cursor)
        if folder is None:
            # Broken chain: a parent folder was hard-deleted.
            return None
        parts.append(safe_folder_segment(str(folder.name or "")))
        cursor = folder.parent_id
    # The walk collected child-first; flip to root-first for path joining.
    parts.reverse()
    base = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT
    filename = safe_filename(title)
    # BUG FIX: previously returned the literal "(unknown)" instead of the
    # sanitized filename computed on the line above, producing paths like
    # "<root>/folder/(unknown)" for every document.
    return f"{base}/{filename}"
# ---------------------------------------------------------------------------
# Document revision restore (write/edit/move/rm)
# ---------------------------------------------------------------------------
def _set_field(target: Any, field: str, value: Any) -> None:
if value is not None:
setattr(target, field, value)
async def _restore_in_place_document(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Apply an in-place restore to an existing :class:`Document`.

    Copies the ``*_before`` columns of ``revision`` back onto the live row
    (``None`` snapshot fields are skipped via ``_set_field``), recomputes the
    content / identifier hashes, and rebuilds chunks + embeddings from the
    snapshot. Returns ``tool_unavailable`` when the live row no longer exists.
    """
    if revision.document_id is None:
        # Snapshot no longer points at a live row — nothing to restore onto.
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Original document was hard-deleted; in-place restore is not possible."
            ),
        )
    doc = await session.get(Document, revision.document_id)
    if doc is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original document has been deleted; revert cannot proceed.",
        )
    # None-safe field restores: untouched (NULL) snapshot columns keep the
    # live row's current value.
    _set_field(doc, "content", revision.content_before)
    _set_field(doc, "source_markdown", revision.content_before)
    _set_field(doc, "title", revision.title_before)
    _set_field(doc, "folder_id", revision.folder_id_before)
    metadata_before = revision.metadata_before or {}
    if isinstance(metadata_before, dict) and metadata_before:
        # Copy so later mutations of the snapshot dict can't leak into the row.
        doc.document_metadata = dict(metadata_before)
    if isinstance(revision.content_before, str):
        doc.content_hash = generate_content_hash(
            revision.content_before, doc.search_space_id
        )
    virtual_path = await _virtual_path_from_snapshot(session, revision)
    if virtual_path:
        # Re-derive the dedupe identifier from the pre-mutation path so the
        # restored row matches what a fresh write at that path would produce.
        doc.unique_identifier_hash = generate_unique_identifier_hash(
            DocumentType.NOTE,
            virtual_path,
            doc.search_space_id,
        )
    chunks_before = revision.chunks_before
    if isinstance(chunks_before, list):
        # Snapshot captured chunks: replace wholesale — delete live chunks
        # first, then re-insert + re-embed from the snapshot payload.
        await session.execute(delete(Chunk).where(Chunk.document_id == doc.id))
        chunk_texts = [
            str(c.get("content"))
            for c in chunks_before
            if isinstance(c, dict) and isinstance(c.get("content"), str)
        ]
        if chunk_texts:
            chunk_embeddings = embed_texts(chunk_texts)
            session.add_all(
                [
                    Chunk(document_id=doc.id, content=text, embedding=embedding)
                    for text, embedding in zip(
                        chunk_texts, chunk_embeddings, strict=True
                    )
                ]
            )
    if isinstance(revision.content_before, str):
        # Document-level embedding is recomputed from the restored content.
        doc.embedding = embed_texts([revision.content_before])[0]
    doc.updated_at = datetime.now(UTC)
    return RevertOutcome(status="ok", message="Document restored from snapshot.")
async def _reinsert_document_from_revision(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Re-INSERT a deleted :class:`Document` from a snapshot row (``rm`` revert).

    Validates the snapshot has enough data (title, content, resolvable
    virtual path), refuses to clobber a live document at the same path,
    then inserts a fresh NOTE row with re-derived hashes, embedding, and
    chunks. On success the snapshot's ``document_id`` is repointed at the
    new row so it can be reverted again later.
    """
    if not isinstance(revision.title_before, str) or not revision.title_before:
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks title_before; cannot recreate document.",
        )
    if not isinstance(revision.content_before, str):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks content_before; cannot recreate document.",
        )
    virtual_path = await _virtual_path_from_snapshot(session, revision)
    if not virtual_path:
        return RevertOutcome(
            status="not_reversible",
            message=(
                "Snapshot is missing both metadata_before['virtual_path'] AND "
                "a resolvable (folder_id_before, title_before) pair."
            ),
        )
    search_space_id = revision.search_space_id
    unique_identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )
    # Collision guard: a live document already occupying the same
    # (search_space, path-hash) slot blocks the re-insert.
    collision = await session.execute(
        select(Document.id).where(
            Document.search_space_id == search_space_id,
            Document.unique_identifier_hash == unique_identifier_hash,
        )
    )
    if collision.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                f"A document already exists at '{virtual_path}'; revert would "
                "collide. Move the live doc out of the way first."
            ),
        )
    # Normalize metadata to a private dict copy and stamp the resolved path
    # so the recreated row carries the same hint future snapshots rely on.
    metadata = revision.metadata_before or {}
    if not isinstance(metadata, dict):
        metadata = {}
    metadata = dict(metadata)
    metadata["virtual_path"] = virtual_path
    content = revision.content_before
    new_doc = Document(
        title=revision.title_before,
        document_type=DocumentType.NOTE,
        document_metadata=metadata,
        content=content,
        content_hash=generate_content_hash(content, search_space_id),
        unique_identifier_hash=unique_identifier_hash,
        source_markdown=content,
        search_space_id=search_space_id,
        folder_id=revision.folder_id_before,
        updated_at=datetime.now(UTC),
    )
    session.add(new_doc)
    # Flush now so ``new_doc.id`` is assigned before chunks reference it.
    await session.flush()
    new_doc.embedding = embed_texts([content])[0]
    chunk_texts = []
    chunks_before = revision.chunks_before
    if isinstance(chunks_before, list):
        chunk_texts = [
            str(c.get("content"))
            for c in chunks_before
            if isinstance(c, dict) and isinstance(c.get("content"), str)
        ]
    if chunk_texts:
        chunk_embeddings = embed_texts(chunk_texts)
        session.add_all(
            [
                Chunk(document_id=new_doc.id, content=text, embedding=embedding)
                for text, embedding in zip(chunk_texts, chunk_embeddings, strict=True)
            ]
        )
    # Repoint the snapshot at the recreated row so a follow-up revert of
    # the same row works as expected.
    revision.document_id = new_doc.id
    return RevertOutcome(
        status="ok",
        message=f"Re-inserted document '{revision.title_before}' from snapshot.",
    )
async def _delete_created_document(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Undo a ``write_file`` create by deleting the document it produced.

    A create is recognized upstream by ``content_before IS NULL`` on the
    snapshot; here only the ``document_id`` it points at matters.
    Idempotent: a snapshot with no live row is reported as success.
    """
    doc_id = revision.document_id
    if doc_id is None:
        # Already gone (deleted elsewhere) — treat as done.
        return RevertOutcome(
            status="ok",
            message="No live row to delete (already removed elsewhere).",
        )
    stmt = delete(Document).where(Document.id == doc_id)
    await session.execute(stmt)
    return RevertOutcome(
        status="ok",
        message="Deleted the document that was created by this action.",
    )
async def _restore_document_revision(
session: AsyncSession, *, action: AgentActionLog
) -> RevertOutcome:
"""Restore the most recent :class:`DocumentRevision` for ``action``."""
"""Dispatch document-level revert based on ``action.tool_name``."""
stmt = (
select(DocumentRevision)
.where(DocumentRevision.agent_action_id == action.id)
@ -132,23 +383,111 @@ async def _restore_document_revision(
message="No document_revisions row tied to this action.",
)
from app.db import Document # late import to avoid cycles at module load
tool_name = (action.tool_name or "").lower()
doc = await session.get(Document, revision.document_id)
if doc is None:
if tool_name == "rm":
return await _reinsert_document_from_revision(session, revision=revision)
if tool_name == "write_file" and revision.content_before is None:
return await _delete_created_document(session, revision=revision)
return await _restore_in_place_document(session, revision=revision)
# ---------------------------------------------------------------------------
# Folder revision restore (mkdir/rmdir/rename/move)
# ---------------------------------------------------------------------------
async def _restore_in_place_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Apply an in-place restore to an existing :class:`Folder`.

    Copies ``name_before`` / ``parent_id_before`` / ``position_before``
    back onto the live row (``None`` snapshot fields are skipped via
    ``_set_field``). Returns ``tool_unavailable`` when the live folder
    no longer exists.
    """
    if revision.folder_id is None:
        # FIX: removed a duplicated ``message=`` keyword argument (diff
        # residue) that was a syntax error; the folder-specific text is
        # the intended one.
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder was hard-deleted; in-place restore is impossible.",
        )
    folder = await session.get(Folder, revision.folder_id)
    if folder is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder has been deleted; revert cannot proceed.",
        )
    # None-safe restores: NULL snapshot columns keep the live value.
    _set_field(folder, "name", revision.name_before)
    _set_field(folder, "parent_id", revision.parent_id_before)
    _set_field(folder, "position", revision.position_before)
    folder.updated_at = datetime.now(UTC)
    return RevertOutcome(status="ok", message="Folder restored from snapshot.")
async def _reinsert_folder_from_revision(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Recreate a folder deleted by ``rmdir`` from its snapshot row."""
    name = revision.name_before
    if not (isinstance(name, str) and name):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks name_before; cannot recreate folder.",
        )
    restored = Folder(
        name=name,
        parent_id=revision.parent_id_before,
        position=revision.position_before,
        search_space_id=revision.search_space_id,
        updated_at=datetime.now(UTC),
    )
    session.add(restored)
    # Flush so the new primary key is assigned before we record it.
    await session.flush()
    # Repoint the snapshot at the recreated row so a follow-up revert of
    # the same row works as expected.
    revision.folder_id = restored.id
    return RevertOutcome(
        status="ok",
        message=f"Re-inserted folder '{revision.name_before}' from snapshot.",
    )
async def _delete_created_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Delete the folder that ``mkdir`` created, refusing if it is no longer empty.

    Emptiness is checked in two passes — any document parked in the folder,
    then any child folder — so the revert never cascades into content added
    after the original ``mkdir``. Idempotent: a snapshot with no live row
    is reported as success.

    FIX: removed stray diff-residue lines (document-restore statements
    referencing an undefined ``doc``) that had been spliced in before the
    final DELETE.
    """
    if revision.folder_id is None:
        return RevertOutcome(
            status="ok",
            message="No live folder row to delete (already removed elsewhere).",
        )
    folder_id = revision.folder_id
    # Pass 1: any document inside the folder blocks the revert.
    has_doc = await session.execute(
        select(Document.id).where(Document.folder_id == folder_id).limit(1)
    )
    if has_doc.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Folder is no longer empty (documents have been added since "
                "mkdir); cannot revert."
            ),
        )
    # Pass 2: any child folder blocks it too.
    has_child = await session.execute(
        select(Folder.id).where(Folder.parent_id == folder_id).limit(1)
    )
    if has_child.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Folder is no longer empty (sub-folders have been added "
                "since mkdir); cannot revert."
            ),
        )
    await session.execute(delete(Folder).where(Folder.id == folder_id))
    return RevertOutcome(
        status="ok",
        message="Deleted the folder that was created by this action.",
    )
async def _restore_folder_revision(
@ -168,41 +507,44 @@ async def _restore_folder_revision(
message="No folder_revisions row tied to this action.",
)
from app.db import Folder
tool_name = (action.tool_name or "").lower()
folder = await session.get(Folder, revision.folder_id)
if folder is None:
return RevertOutcome(
status="tool_unavailable",
message="Original folder has been deleted; revert cannot proceed.",
)
if tool_name == "rmdir":
return await _reinsert_folder_from_revision(session, revision=revision)
if revision.name_before is not None:
folder.name = revision.name_before
if revision.parent_id_before is not None:
folder.parent_id = revision.parent_id_before
if revision.position_before is not None:
folder.position = revision.position_before
folder.updated_at = datetime.now(UTC)
return RevertOutcome(status="ok", message="Folder restored from snapshot.")
if tool_name == "mkdir":
return await _delete_created_folder(session, revision=revision)
return await _restore_in_place_folder(session, revision=revision)
# Tool-name prefixes that route to KB document / folder revert paths. Kept
# as data so a future PR adding new KB-owned tools doesn't have to touch
# this module's control flow.
_DOC_TOOL_PREFIXES: tuple[str, ...] = (
"edit_file",
"write_file",
"update_memory",
"create_note",
"update_note",
"delete_note",
# ---------------------------------------------------------------------------
# Dispatch
# ---------------------------------------------------------------------------
#
# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``.
# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and
# ``delete_note``/``delete_folder``.
_DOC_TOOLS: frozenset[str] = frozenset(
{
"edit_file",
"write_file",
"move_file",
"rm",
"update_memory",
"create_note",
"update_note",
"delete_note",
}
)
_FOLDER_TOOL_PREFIXES: tuple[str, ...] = (
"mkdir",
"move_file",
"rename_folder",
"delete_folder",
_FOLDER_TOOLS: frozenset[str] = frozenset(
{
"mkdir",
"rmdir",
"rename_folder",
"delete_folder",
}
)
@ -220,9 +562,9 @@ async def revert_action(
"""
tool_name = (action.tool_name or "").lower()
if tool_name.startswith(_DOC_TOOL_PREFIXES):
if tool_name in _DOC_TOOLS:
outcome = await _restore_document_revision(session, action=action)
elif tool_name.startswith(_FOLDER_TOOL_PREFIXES):
elif tool_name in _FOLDER_TOOLS:
outcome = await _restore_folder_revision(session, action=action)
elif action.reverse_descriptor:
# Connector-owned reversibles run through the normal permission

View file

@ -31,6 +31,7 @@ from sqlalchemy.orm import selectinload
from app.agents.multi_agent_chat.integration import create_multi_agent_chat
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.feature_flags import get_flags
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.llm_config import (
AgentConfig,
@ -72,6 +73,91 @@ _perf_log = get_perf_logger()
logger = logging.getLogger(__name__)
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
Returns a dict with three keys:
* ``text`` concatenated string content (empty string if the chunk
contributes none).
* ``reasoning`` concatenated reasoning content (empty string if the
chunk contributes none).
* ``tool_call_chunks`` flat list of LangChain ``tool_call_chunk``
dicts surfaced from either the typed-block list or the
``tool_call_chunks`` attribute.
Background
----------
``AIMessageChunk.content`` can be:
* a ``str`` (most providers), or
* a ``list`` of typed blocks ``{type: 'text' | 'reasoning' |
'tool_call_chunk' | 'tool_use' | ..., text/content/...}`` for
Anthropic, Bedrock, and several reasoning configurations.
Reasoning may also live under
``chunk.additional_kwargs['reasoning_content']`` (some providers
surface it that way instead of as a typed block). Tool-call chunks
may live under ``chunk.tool_call_chunks`` even when ``content`` is a
plain string.
Earlier versions only handled the ``isinstance(content, str)`` branch
and silently dropped reasoning blocks + tool-call chunks emitted by
LangChain ``AIMessageChunk``s.
"""
out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []}
if chunk is None:
return out
content = getattr(chunk, "content", None)
if isinstance(content, str):
if content:
out["text"] = content
elif isinstance(content, list):
text_parts: list[str] = []
reasoning_parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
block_type = block.get("type")
if block_type == "text":
value = block.get("text") or block.get("content") or ""
if isinstance(value, str) and value:
text_parts.append(value)
elif block_type == "reasoning":
value = (
block.get("reasoning")
or block.get("text")
or block.get("content")
or ""
)
if isinstance(value, str) and value:
reasoning_parts.append(value)
elif block_type in ("tool_call_chunk", "tool_use"):
out["tool_call_chunks"].append(block)
if text_parts:
out["text"] = "".join(text_parts)
if reasoning_parts:
out["reasoning"] = "".join(reasoning_parts)
additional = getattr(chunk, "additional_kwargs", None) or {}
if isinstance(additional, dict):
extra_reasoning = additional.get("reasoning_content")
if isinstance(extra_reasoning, str) and extra_reasoning:
existing = out["reasoning"]
out["reasoning"] = (
(existing + extra_reasoning) if existing else extra_reasoning
)
extra_tool_chunks = getattr(chunk, "tool_call_chunks", None)
if isinstance(extra_tool_chunks, list):
for tcc in extra_tool_chunks:
if isinstance(tcc, dict):
out["tool_call_chunks"].append(tcc)
return out
def format_mentioned_surfsense_docs_as_context(
documents: list[SurfsenseDocsDocument],
) -> str:
@ -254,6 +340,42 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
)
def _legacy_match_lc_id(
pending_tool_call_chunks: list[dict[str, Any]],
tool_name: str,
run_id: str,
lc_tool_call_id_by_run: dict[str, str],
) -> str | None:
"""Best-effort match a buffered ``tool_call_chunk`` to a tool name.
Pure extract of the legacy in-line match used at ``on_tool_start`` for
parity_v2-OFF and unmatched (chunk path didn't register an index for
this call) tools. Pops the next id-bearing chunk whose ``name``
matches ``tool_name`` (or any id-bearing chunk as a fallback) and
returns its id. Mutates ``pending_tool_call_chunks`` and
``lc_tool_call_id_by_run`` in place.
"""
matched_idx: int | None = None
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("name") == tool_name and tcc.get("id"):
matched_idx = idx
break
if matched_idx is None:
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("id"):
matched_idx = idx
break
if matched_idx is None:
return None
matched = pending_tool_call_chunks.pop(matched_idx)
candidate = matched.get("id")
if isinstance(candidate, str) and candidate:
if run_id:
lc_tool_call_id_by_run[run_id] = candidate
return candidate
return None
async def _stream_agent_events(
agent: Any,
config: dict[str, Any],
@ -268,6 +390,7 @@ async def _stream_agent_events(
fallback_commit_search_space_id: int | None = None,
fallback_commit_created_by_id: str | None = None,
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
fallback_commit_thread_id: int | None = None,
) -> AsyncGenerator[str, None]:
"""Shared async generator that streams and formats astream_events from the agent.
@ -300,6 +423,59 @@ async def _stream_agent_events(
active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool
called_update_memory: bool = False
# Reasoning-block streaming. We open a reasoning block on the
# first reasoning delta of a step, append deltas as they arrive, and
# close it when text starts (the model has switched to writing its
# answer) or ``on_chat_model_end`` fires for the model node. Reuses
# the same Vercel format-helpers as text-start/delta/end.
current_reasoning_id: str | None = None
# Streaming-parity v2 feature flag. When OFF we keep the legacy
# shape: str-only content, no reasoning blocks, no
# ``langchainToolCallId`` propagation. The schema migrations
# (135 / 136) ship unconditionally because they're forward-compatible.
parity_v2 = bool(get_flags().enable_stream_parity_v2)
# Best-effort attach of LangChain ``tool_call_id`` to the synthetic
# ``call_<run_id>`` card id we already emit. We accumulate
# ``tool_call_chunks`` from ``on_chat_model_stream``, key them by
# name, and pop the next unconsumed entry at ``on_tool_start``. The
# authoritative id is later filled in at ``on_tool_end`` from
# ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit
# this list for chunks that already registered into ``index_to_meta``
# below — so this list is reserved for the parity_v2-OFF / unmatched
# fallback path only and never re-pops a chunk we already streamed.
pending_tool_call_chunks: list[dict[str, Any]] = []
lc_tool_call_id_by_run: dict[str, str] = {}
# parity_v2 only: live tool-call argument streaming. ``index_to_meta``
# is keyed by the chunk's ``index`` field — LangChain
# ``ToolCallChunk``s for the same call share an index but only the
# first chunk carries id+name (subsequent ones are id=None,
# name=None, args="<delta>"). We register an index when both id and
# name are observed on a chunk (per ToolCallChunk semantics they
# arrive together on the first chunk), then route every later chunk
# at that index to the same ``ui_id`` as a ``tool-input-delta``.
# ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the
# ``ui_id`` used for that call's ``tool-input-start`` so the matching
# ``tool-output-available`` (emitted from ``on_tool_end``) lands on
# the same card.
index_to_meta: dict[int, dict[str, str]] = {}
ui_tool_call_id_by_run: dict[str, str] = {}
# Per-tool-end mutable cache for the LangChain tool_call_id resolved
# at ``on_tool_end``. ``_emit_tool_output`` reads this so every
# ``format_tool_output_available`` call automatically carries the
# authoritative id without duplicating the kwarg at every call site.
current_lc_tool_call_id: dict[str, str | None] = {"value": None}
def _emit_tool_output(call_id: str, output: Any) -> str:
return streaming_service.format_tool_output_available(
call_id,
output,
langchain_tool_call_id=current_lc_tool_call_id["value"],
)
def next_thinking_step_id() -> str:
nonlocal thinking_step_counter
thinking_step_counter += 1
@ -328,22 +504,119 @@ async def _stream_agent_events(
if "surfsense:internal" in event.get("tags", []):
continue # Suppress middleware-internal LLM tokens (e.g. KB search classification)
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content"):
content = chunk.content
if content and isinstance(content, str):
if current_text_id is None:
completion_event = complete_current_step()
if completion_event:
yield completion_event
if just_finished_tool:
last_active_step_id = None
last_active_step_title = ""
last_active_step_items = []
just_finished_tool = False
current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(current_text_id)
yield streaming_service.format_text_delta(current_text_id, content)
accumulated_text += content
if not chunk:
continue
parts = _extract_chunk_parts(chunk)
reasoning_delta = parts["reasoning"]
text_delta = parts["text"]
# Reasoning streaming. Open a reasoning block on first
# delta; append every subsequent delta until text begins.
# When text starts we close the reasoning block first so the
# frontend sees the natural hand-off. Gated behind the
# parity-v2 flag so legacy deployments keep today's shape.
if parity_v2 and reasoning_delta:
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is None:
completion_event = complete_current_step()
if completion_event:
yield completion_event
if just_finished_tool:
last_active_step_id = None
last_active_step_title = ""
last_active_step_items = []
just_finished_tool = False
current_reasoning_id = streaming_service.generate_reasoning_id()
yield streaming_service.format_reasoning_start(current_reasoning_id)
yield streaming_service.format_reasoning_delta(
current_reasoning_id, reasoning_delta
)
if text_delta:
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(current_reasoning_id)
current_reasoning_id = None
if current_text_id is None:
completion_event = complete_current_step()
if completion_event:
yield completion_event
if just_finished_tool:
last_active_step_id = None
last_active_step_title = ""
last_active_step_items = []
just_finished_tool = False
current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(current_text_id)
yield streaming_service.format_text_delta(current_text_id, text_delta)
accumulated_text += text_delta
# Live tool-call argument streaming. Runs AFTER text/reasoning
# processing so chunks containing both stay in their natural
# wire order (text → text-end → tool-input-start). Active
# text/reasoning are closed inside the registration branch
# before ``tool-input-start`` so the frontend sees a clean
# part boundary even when providers interleave.
if parity_v2 and parts["tool_call_chunks"]:
for tcc in parts["tool_call_chunks"]:
idx = tcc.get("index")
# Register this index when we first see id+name
# TOGETHER. Per LangChain ToolCallChunk semantics the
# first chunk for a tool call carries both fields
# together; later chunks have id=None, name=None and
# only ``args``. Requiring BOTH keeps wire
# ``tool-input-start`` always carrying a real
# toolName (assistant-ui's typed tool-part dispatch
# keys off it).
if idx is not None and idx not in index_to_meta:
lc_id = tcc.get("id")
name = tcc.get("name")
if lc_id and name:
ui_id = lc_id
# Close active text/reasoning so wire
# ordering stays clean even on providers
# that interleave text and tool-call chunks
# within the same stream window.
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(
current_reasoning_id
)
current_reasoning_id = None
index_to_meta[idx] = {
"ui_id": ui_id,
"lc_id": lc_id,
"name": name,
}
yield streaming_service.format_tool_input_start(
ui_id,
name,
langchain_tool_call_id=lc_id,
)
# Emit args delta for any chunk at a registered
# index (including idless continuations). Once an
# index is owned by ``index_to_meta`` we DO NOT
# append to ``pending_tool_call_chunks`` — that list
# is reserved for the parity_v2-OFF / unmatched
# fallback path so it never re-pops chunks already
# consumed here (skip-append).
meta = index_to_meta.get(idx) if idx is not None else None
if meta:
args_chunk = tcc.get("args") or ""
if args_chunk:
yield streaming_service.format_tool_input_delta(
meta["ui_id"], args_chunk
)
else:
pending_tool_call_chunks.append(tcc)
elif event_type == "on_tool_start":
active_tool_depth += 1
@ -463,6 +736,95 @@ async def _stream_agent_events(
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "rm":
rm_path = (
tool_input.get("path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_path = rm_path if len(rm_path) <= 80 else "" + rm_path[-77:]
last_active_step_title = "Deleting file"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Deleting file",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "rmdir":
rmdir_path = (
tool_input.get("path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_path = (
rmdir_path if len(rmdir_path) <= 80 else "" + rmdir_path[-77:]
)
last_active_step_title = "Deleting folder"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Deleting folder",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "mkdir":
mkdir_path = (
tool_input.get("path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_path = (
mkdir_path if len(mkdir_path) <= 80 else "" + mkdir_path[-77:]
)
last_active_step_title = "Creating folder"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Creating folder",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "move_file":
src = (
tool_input.get("source_path", "")
if isinstance(tool_input, dict)
else ""
)
dst = (
tool_input.get("destination_path", "")
if isinstance(tool_input, dict)
else ""
)
display_src = src if len(src) <= 60 else "" + src[-57:]
display_dst = dst if len(dst) <= 60 else "" + dst[-57:]
last_active_step_title = "Moving file"
last_active_step_items = (
[f"{display_src}{display_dst}"] if src or dst else []
)
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Moving file",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "write_todos":
todos = (
tool_input.get("todos", []) if isinstance(tool_input, dict) else []
)
todo_count = len(todos) if isinstance(todos, list) else 0
last_active_step_title = "Planning tasks"
last_active_step_items = (
[f"{todo_count} task{'s' if todo_count != 1 else ''}"]
if todo_count
else []
)
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Planning tasks",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "save_document":
doc_title = (
tool_input.get("title", "")
@ -570,7 +932,15 @@ async def _stream_agent_events(
items=last_active_step_items,
)
else:
last_active_step_title = f"Using {tool_name.replace('_', ' ')}"
# Fallback for tools without a curated thinking-step title
# (typically connector tools, MCP-registered tools, or
# newly added tools that haven't been wired up here yet).
# Render the snake_cased name as a sentence-cased phrase
# so non-technical users see e.g. "Send gmail email"
# rather than the raw identifier "send_gmail_email".
last_active_step_title = (
tool_name.replace("_", " ").strip().capitalize() or tool_name
)
last_active_step_items = []
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
@ -578,12 +948,65 @@ async def _stream_agent_events(
status="in_progress",
)
tool_call_id = (
f"call_{run_id[:32]}"
if run_id
else streaming_service.generate_tool_call_id()
)
yield streaming_service.format_tool_input_start(tool_call_id, tool_name)
# Resolve the card identity. If the chunk-emission loop
# already registered an ``index`` for this tool call (parity_v2
# path), reuse the same ui_id so the card sees:
# tool-input-start → deltas… → tool-input-available →
# tool-output-available all keyed by lc_id. Otherwise fall
# back to the synthetic ``call_<run_id>`` id and the legacy
# best-effort match against ``pending_tool_call_chunks``.
matched_meta: dict[str, str] | None = None
if parity_v2:
# FIFO over indices 0,1,2…; first unassigned same-name
# match wins. Handles parallel same-name calls (e.g. two
# write_file calls) deterministically as long as the
# model interleaves on_tool_start in the same order it
# streamed the args.
taken_ui_ids = set(ui_tool_call_id_by_run.values())
for meta in index_to_meta.values():
if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
matched_meta = meta
break
tool_call_id: str
langchain_tool_call_id: str | None = None
if matched_meta is not None:
tool_call_id = matched_meta["ui_id"]
langchain_tool_call_id = matched_meta["lc_id"]
# ``tool-input-start`` already fired during chunk
# emission — skip the duplicate. No pruning is needed
# because the chunk-emission loop intentionally never
# appends registered-index chunks to
# ``pending_tool_call_chunks`` (skip-append).
if run_id:
lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"]
else:
tool_call_id = (
f"call_{run_id[:32]}"
if run_id
else streaming_service.generate_tool_call_id()
)
# Legacy fallback: parity_v2 OFF, or parity_v2 ON but the
# provider didn't stream tool_call_chunks for this call
# (no index registered). Run the existing best-effort
# match BEFORE emitting start so we still attach an
# authoritative ``langchainToolCallId`` when possible.
if parity_v2:
langchain_tool_call_id = _legacy_match_lc_id(
pending_tool_call_chunks,
tool_name,
run_id,
lc_tool_call_id_by_run,
)
yield streaming_service.format_tool_input_start(
tool_call_id,
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
)
if run_id:
ui_tool_call_id_by_run[run_id] = tool_call_id
# Sanitize tool_input: strip runtime-injected non-serializable
# values (e.g. LangChain ToolRuntime) before sending over SSE.
if isinstance(tool_input, dict):
@ -600,6 +1023,7 @@ async def _stream_agent_events(
tool_call_id,
tool_name,
_safe_input,
langchain_tool_call_id=langchain_tool_call_id,
)
elif event_type == "on_tool_end":
@ -635,12 +1059,42 @@ async def _stream_agent_events(
result.write_succeeded = True
result.verification_succeeded = True
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
# Look up the SAME card id used at on_tool_start (either the
# parity_v2 lc-id-derived ui_id or the legacy synthetic
# ``call_<run_id>``) so the output event always lands on the
# same card as start/delta/available. Fallback preserves the
# legacy synthetic shape for parity_v2-OFF / unknown-run paths.
tool_call_id = ui_tool_call_id_by_run.get(
run_id,
f"call_{run_id[:32]}" if run_id else "call_unknown",
)
original_step_id = tool_step_ids.get(
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
)
completed_step_ids.add(original_step_id)
# Authoritative LangChain tool_call_id from the returned
# ``ToolMessage``. Falls back to whatever we matched
# at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``)
# if the output isn't a ToolMessage. The value is stored in
# ``current_lc_tool_call_id`` so ``_emit_tool_output``
# picks it up for every output emit below.
#
# Emitted in BOTH parity_v2 and legacy modes: the chat tool
# card needs the LangChain id to match against the
# ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``)
# so the inline Revert button can light up. Reading
# ``raw_output.tool_call_id`` is a cheap, non-mutating attribute
# access that is safe regardless of feature-flag state.
current_lc_tool_call_id["value"] = None
authoritative = getattr(raw_output, "tool_call_id", None)
if isinstance(authoritative, str) and authoritative:
current_lc_tool_call_id["value"] = authoritative
if run_id:
lc_tool_call_id_by_run[run_id] = authoritative
elif run_id and run_id in lc_tool_call_id_by_run:
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
if tool_name == "read_file":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
@ -676,6 +1130,41 @@ async def _stream_agent_events(
status="completed",
items=last_active_step_items,
)
elif tool_name == "rm":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Deleting file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "rmdir":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Deleting folder",
status="completed",
items=last_active_step_items,
)
elif tool_name == "mkdir":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Creating folder",
status="completed",
items=last_active_step_items,
)
elif tool_name == "move_file":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Moving file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "write_todos":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Planning tasks",
status="completed",
items=last_active_step_items,
)
elif tool_name == "save_document":
result_str = (
tool_output.get("result", "")
@ -927,9 +1416,14 @@ async def _stream_agent_events(
items=completed_items,
)
else:
# Fallback completion title — see the matching in-progress
# branch above for the wording rationale.
fallback_title = (
tool_name.replace("_", " ").strip().capitalize() or tool_name
)
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title=f"Using {tool_name.replace('_', ' ')}",
title=fallback_title,
status="completed",
items=last_active_step_items,
)
@ -940,7 +1434,7 @@ async def _stream_agent_events(
last_active_step_items = []
if tool_name == "generate_podcast":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@ -965,7 +1459,7 @@ async def _stream_agent_events(
"error",
)
elif tool_name == "generate_video_presentation":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@ -993,7 +1487,7 @@ async def _stream_agent_events(
"error",
)
elif tool_name == "generate_image":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@ -1020,12 +1514,12 @@ async def _stream_agent_events(
display_output["content_preview"] = (
content[:500] + "..." if len(content) > 500 else content
)
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
display_output,
)
else:
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{"result": tool_output},
)
@ -1053,7 +1547,7 @@ async def _stream_agent_events(
)
result_text = _tool_output_to_text(tool_output)
if _tool_output_has_error(tool_output):
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{
"status": "error",
@ -1062,7 +1556,7 @@ async def _stream_agent_events(
},
)
else:
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{
"status": "completed",
@ -1072,7 +1566,7 @@ async def _stream_agent_events(
)
elif tool_name == "generate_report":
# Stream the full report result so frontend can render the ReportCard
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@ -1099,7 +1593,7 @@ async def _stream_agent_events(
"error",
)
elif tool_name == "generate_resume":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@ -1150,7 +1644,7 @@ async def _stream_agent_events(
"update_confluence_page",
"delete_confluence_page",
):
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@ -1178,7 +1672,7 @@ async def _stream_agent_events(
if fpath and fpath not in result.sandbox_files:
result.sandbox_files.append(fpath)
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{
"exit_code": exit_code,
@ -1213,12 +1707,12 @@ async def _stream_agent_events(
citations[chunk_url]["snippet"] = (
content[:200] + "" if len(content) > 200 else content
)
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{"status": "completed", "citations": citations},
)
else:
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{"status": "completed", "result_length": len(str(tool_output))},
)
@ -1276,6 +1770,25 @@ async def _stream_agent_events(
},
)
elif event_type == "on_custom_event" and event.get("name") == "action_log":
# Surface a freshly committed AgentActionLog row so the chat
# tool card can render its Revert button immediately.
data = event.get("data", {})
if data.get("id") is not None:
yield streaming_service.format_data("action-log", data)
elif (
event_type == "on_custom_event"
and event.get("name") == "action_log_updated"
):
# Reversibility flipped in kb_persistence after the SAVEPOINT
# for a destructive op (rm/rmdir/move/edit/write) committed.
# Frontend uses this to flip the card's Revert
# button on without re-fetching the actions list.
data = event.get("data", {})
if data.get("id") is not None:
yield streaming_service.format_data("action-log-updated", data)
elif event_type in ("on_chain_end", "on_agent_end"):
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
@ -1293,11 +1806,12 @@ async def _stream_agent_events(
# Safety net: if astream_events was cancelled before
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
# (dirty_paths / staged_dirs / pending_moves) will still be in the
# checkpointed state. Run the SAME shared commit helper here so the
# turn's writes don't get lost on client disconnect, then push the
# delta back into the graph using `as_node=...` so reducers fire as if
# the after_agent hook produced it.
# (dirty_paths / staged_dirs / pending_moves / pending_deletes /
# pending_dir_deletes) will still be in the checkpointed state. Run
# the SAME shared commit helper here so the turn's writes don't get
# lost on client disconnect, then push the delta back into the graph
# using `as_node=...` so reducers fire as if the after_agent hook
# produced it.
if (
fallback_commit_filesystem_mode == FilesystemMode.CLOUD
and fallback_commit_search_space_id is not None
@ -1305,6 +1819,8 @@ async def _stream_agent_events(
(state_values.get("dirty_paths") or [])
or (state_values.get("staged_dirs") or [])
or (state_values.get("pending_moves") or [])
or (state_values.get("pending_deletes") or [])
or (state_values.get("pending_dir_deletes") or [])
)
):
try:
@ -1313,6 +1829,7 @@ async def _stream_agent_events(
search_space_id=fallback_commit_search_space_id,
created_by_id=fallback_commit_created_by_id,
filesystem_mode=fallback_commit_filesystem_mode,
thread_id=fallback_commit_thread_id,
dispatch_events=False,
)
if delta:
@ -1753,13 +2270,33 @@ async def stream_new_chat(
config = {
"configurable": configurable,
"recursion_limit": 80, # Increase from default 25 to allow more tool iterations
# Effectively uncapped, matching the agent-level
# ``with_config`` default in ``chat_deepagent.create_agent``
# and the unbounded ``while(true)`` loop used by OpenCode's
# ``session/processor.ts``. Real circuit-breakers live in
# middleware: ``DoomLoopMiddleware`` (sliding-window tool
# signature check), plus ``enable_tool_call_limit`` /
# ``enable_model_call_limit`` when those flags are set. The
# original LangGraph default of 25 (and our previous 80
# bump) hit users on legitimate multi-tool plans.
"recursion_limit": 10_000,
}
# Start the message stream
yield streaming_service.format_message_start()
yield streaming_service.format_start_step()
# Surface the per-turn correlation id at the very start of the
# stream so the frontend can stamp it onto the in-flight
# assistant message and replay it via ``appendMessage``
# for durable storage. Tool/action-log events DO carry it later,
# but pure-text turns never produce action-log events; this
# event guarantees the frontend learns the turn id regardless.
yield streaming_service.format_data(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)
# Initial thinking step - analyzing the request
if mentioned_surfsense_docs:
initial_title = "Analyzing referenced content"
@ -1910,6 +2447,7 @@ async def stream_new_chat(
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
@ -2353,11 +2891,22 @@ async def stream_resume_chat(
"request_id": request_id or "unknown",
"turn_id": stream_result.turn_id,
},
"recursion_limit": 80,
# See ``stream_new_chat`` above for rationale: effectively
# uncapped to mirror the agent default and OpenCode's
# session loop. Doom-loop / call-limit middleware enforce
# the real ceiling.
"recursion_limit": 10_000,
}
yield streaming_service.format_message_start()
yield streaming_service.format_start_step()
# Same rationale as ``stream_new_chat``: emit the turn id so
# resumed streams can be persisted with their correlation id
# intact.
yield streaming_service.format_data(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)
_t_stream_start = time.perf_counter()
_first_event_logged = False
@ -2375,6 +2924,7 @@ async def stream_resume_chat(
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(

View file

@ -15,6 +15,17 @@ from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
from app.agents.new_chat.tools.registry import ToolDefinition
@dataclass
class _FakeRuntime:
    """Minimal stand-in for ``ToolRuntime`` used in unit tests.

    ``ActionLogMiddleware`` reads ``runtime.config['configurable']['turn_id']``
    to populate the new ``chat_turn_id`` column (see migration 135).
    """

    # LangGraph-style runtime config mapping. Tests either set
    # ``{"configurable": {"turn_id": ...}}`` to feed the middleware a turn id,
    # or leave this None to exercise the NULL-fallback path.
    config: dict[str, Any] | None = None
@dataclass
class _FakeRequest:
"""Minimal stand-in for ToolCallRequest used in unit tests."""
@ -120,6 +131,9 @@ class TestActionLogMiddlewarePersistence:
"args": {"color": "red", "size": 3},
"id": "tc-abc",
},
runtime=_FakeRuntime(
config={"configurable": {"turn_id": "42:1700000000000"}}
),
)
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
handler = AsyncMock(return_value=result_msg)
@ -142,6 +156,32 @@ class TestActionLogMiddlewarePersistence:
assert row.error is None
assert row.reverse_descriptor is None
assert row.reversible is False
# Migration 135: ``turn_id`` is the deprecated alias of ``tool_call_id``;
# ``chat_turn_id`` comes from ``runtime.config['configurable']['turn_id']``.
assert row.tool_call_id == "tc-abc"
assert row.turn_id == "tc-abc"
assert row.chat_turn_id == "42:1700000000000"
@pytest.mark.asyncio
async def test_chat_turn_id_none_when_runtime_missing(
self, patch_get_flags, fake_session_factory
) -> None:
"""``chat_turn_id`` falls back to NULL when ``runtime.config`` is absent."""
captured, factory = fake_session_factory
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
request = _FakeRequest(
tool_call={"name": "make_widget", "args": {}, "id": "tc-1"},
runtime=None,
)
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc-1"))
with (
patch_get_flags(_enabled_flags()),
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
):
await mw.awrap_tool_call(request, handler)
row = captured["rows"][0]
assert row.tool_call_id == "tc-1"
assert row.chat_turn_id is None
@pytest.mark.asyncio
async def test_writes_row_on_failure_and_reraises(
@ -293,6 +333,76 @@ class TestReverseDescriptor:
assert row.reversible is False
class TestActionLogDispatch:
    """Verify ``adispatch_custom_event`` fires after commit."""

    @pytest.mark.asyncio
    async def test_dispatches_action_log_event_on_success(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        """A committed row is surfaced as an ``action_log`` custom event."""
        _captured, factory = fake_session_factory
        middleware = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
        req = _FakeRequest(
            tool_call={
                "name": "make_widget",
                "args": {"color": "red"},
                "id": "tc-evt",
            },
            runtime=_FakeRuntime(
                config={"configurable": {"turn_id": "42:1700000000000"}}
            ),
        )
        tool_handler = AsyncMock(
            return_value=ToolMessage(content="ok", tool_call_id="tc-evt", id="msg-42")
        )
        dispatched = AsyncMock()

        with (
            patch_get_flags(_enabled_flags()),
            patch("app.db.shielded_async_session", side_effect=lambda: factory()),
            patch(
                "app.agents.new_chat.middleware.action_log.adispatch_custom_event",
                dispatched,
            ),
        ):
            await middleware.awrap_tool_call(req, tool_handler)

        dispatched.assert_awaited_once()
        awaited_with = dispatched.await_args
        assert awaited_with is not None
        assert awaited_with.args[0] == "action_log"
        payload = awaited_with.args[1]
        assert payload["lc_tool_call_id"] == "tc-evt"
        assert payload["chat_turn_id"] == "42:1700000000000"
        assert payload["tool_name"] == "make_widget"
        assert payload["reversible"] is False
        assert payload["reverse_descriptor_present"] is False
        assert payload["error"] is False

    @pytest.mark.asyncio
    async def test_no_dispatch_when_persistence_fails(self, patch_get_flags) -> None:
        """If commit fails the dispatch is suppressed (no row to surface)."""
        middleware = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        req = _FakeRequest(
            tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
        )
        tool_handler = AsyncMock(
            return_value=ToolMessage(content="ok", tool_call_id="tc1")
        )
        dispatched = AsyncMock()

        def _exploding_session():
            # Simulate the session factory itself blowing up before any commit.
            raise RuntimeError("DB is down")

        with (
            patch_get_flags(_enabled_flags()),
            patch("app.db.shielded_async_session", side_effect=_exploding_session),
            patch(
                "app.agents.new_chat.middleware.action_log.adispatch_custom_event",
                dispatched,
            ),
        ):
            await middleware.awrap_tool_call(req, tool_handler)

        dispatched.assert_not_awaited()
class TestArgsTruncation:
@pytest.mark.asyncio
async def test_huge_args_payload_is_truncated(

View file

@ -0,0 +1,122 @@
"""Tests for the desktop-mode safety ruleset.
In desktop mode the agent operates against the user's real disk with no
revision history, so destructive filesystem operations must require
explicit approval. These tests pin the set of tools that get the ``ask``
gate so it cannot silently regress.
"""
from __future__ import annotations
import pytest
from app.agents.new_chat.middleware.permission import PermissionMiddleware
from app.agents.new_chat.permissions import (
Rule,
Ruleset,
aggregate_action,
evaluate_many,
)
pytestmark = pytest.mark.unit
# Mirror the ruleset built inside ``chat_deepagent._build_compiled_agent_blocking``
# when ``filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER``. Keeping a
# copy here means the rule contract has a focused regression test even when
# the larger graph-build helper is hard to instantiate in unit tests.
# Every destructive filesystem verb gets the same "ask" gate; build the rule
# list from one tuple so the set of gated tools reads at a glance.
DESKTOP_SAFETY_RULESET = Ruleset(
    rules=[
        Rule(permission=destructive, pattern="*", action="ask")
        for destructive in ("rm", "rmdir", "move_file", "edit_file", "write_file")
    ],
    origin="desktop_safety",
)

# Baseline layer: allow everything. Desktop safety must be layered AFTER this
# for its "ask" rules to win (last-match-wins).
SURFSENSE_DEFAULTS = Ruleset(
    rules=[Rule(permission="*", pattern="*", action="allow")],
    origin="surfsense_defaults",
)
def _action_for(tool_name: str, *rulesets: Ruleset) -> str:
    """Resolve the aggregate permission action for *tool_name* across *rulesets*."""
    matched = evaluate_many(tool_name, [tool_name], *rulesets)
    return aggregate_action(matched)
class TestDesktopSafetyRulesGateDestructiveOps:
    """Destructive ops require approval; read-only ops stay frictionless."""

    @pytest.mark.parametrize(
        "tool_name",
        ("rm", "rmdir", "move_file", "edit_file", "write_file"),
    )
    def test_destructive_op_resolves_to_ask(self, tool_name: str) -> None:
        # surfsense_defaults says "allow */*"; desktop_safety must override
        # because it's layered later (last-match-wins).
        resolved = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        assert resolved == "ask", (
            f"{tool_name} must require approval in desktop mode "
            f"(no revert path on real disk); got {resolved!r}"
        )

    @pytest.mark.parametrize(
        "tool_name",
        ("read_file", "ls", "list_tree", "grep", "glob", "cd", "pwd", "mkdir"),
    )
    def test_safe_ops_remain_allowed(self, tool_name: str) -> None:
        # Read-only and trivially-reversible tools must NOT get gated —
        # otherwise every navigation in desktop mode pops an interrupt.
        resolved = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        assert resolved == "allow", (
            f"{tool_name} should not be gated in desktop mode; got {resolved!r}"
        )
class TestDesktopSafetyOverridesAllowDefault:
    """Layer order is load-bearing: the safety ruleset must come last."""

    def test_layer_order_last_match_wins(self) -> None:
        # If desktop_safety is layered BEFORE surfsense_defaults, the allow
        # default would win and the safety net would be inert. This test
        # protects against accidentally swapping the rulesets in
        # ``_build_compiled_agent_blocking``.
        wrong_order = _action_for("rm", DESKTOP_SAFETY_RULESET, SURFSENSE_DEFAULTS)
        assert wrong_order == "allow"  # layered the wrong way, broad allow wins

        right_order = _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        assert right_order == "ask"  # defaults < desktop_safety -> ask wins
class TestPermissionMiddlewareIntegration:
    """End-to-end: an 'ask' resolution surfaces as an interrupt/rejection."""

    def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None:
        from langchain_core.messages import AIMessage

        from app.agents.new_chat.errors import RejectedError

        mw = PermissionMiddleware(
            rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET]
        )
        # Stub the interrupt to a "reject" decision so we can assert the
        # ask path was taken without spinning up the LangGraph runtime.
        mw._raise_interrupt = lambda **kw: {"decision_type": "reject"}  # type: ignore[assignment]

        rm_call = {
            "name": "rm",
            "args": {"path": "/Users/me/Documents/important.docx"},
            "id": "tc-rm",
        }
        state = {"messages": [AIMessage(content="", tool_calls=[rm_call])]}

        class _FakeRuntime:
            config: dict = {"configurable": {"thread_id": "test"}}

        with pytest.raises(RejectedError):
            mw.after_model(state, _FakeRuntime())

View file

@ -0,0 +1,111 @@
"""Tests for the default auto-approval list in ``hitl.request_approval``.
These pin the policy that low-stakes connector creation tools (drafts,
new-file creates) skip the HITL interrupt by default. Without this set,
every "draft my newsletter" turn used to fire ~3 interrupts before any
useful work happened.
"""
from __future__ import annotations
import pytest
from app.agents.new_chat.tools.hitl import (
DEFAULT_AUTO_APPROVED_TOOLS,
HITLResult,
request_approval,
)
pytestmark = pytest.mark.unit
class TestDefaultAutoApprovedToolsList:
    """Pin the exact membership and immutability of the default bypass list."""

    def test_set_contains_expected_creation_tools(self) -> None:
        # If anyone changes the policy list, we want a single test to
        # update so the contract is explicit. Keep this in sync with
        # ``hitl.DEFAULT_AUTO_APPROVED_TOOLS``.
        assert DEFAULT_AUTO_APPROVED_TOOLS == {
            "create_gmail_draft",
            "update_gmail_draft",
            "create_notion_page",
            "create_confluence_page",
            "create_google_drive_file",
            "create_dropbox_file",
            "create_onedrive_file",
        }

    def test_set_is_immutable(self) -> None:
        # frozenset prevents accidental at-runtime mutation that would
        # silently widen the auto-approval surface.
        assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset)

    def test_send_tools_are_not_auto_approved(self) -> None:
        # External-broadcast tools must always prompt.
        always_gated = (
            "send_gmail_email",
            "send_discord_message",
            "send_teams_message",
            "delete_notion_page",
            "create_calendar_event",
            "delete_calendar_event",
        )
        for tool_name in always_gated:
            assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, (
                f"{tool_name} must remain HITL-gated"
            )
class TestRequestApprovalAutoBypass:
    """Exercise the ordering of bypass checks inside ``request_approval``."""

    def test_auto_approved_tool_skips_interrupt(self) -> None:
        # No interrupt mock set up — if the function attempted to call
        # ``langgraph.types.interrupt`` it would raise GraphInterrupt.
        # The fact that we get a clean HITLResult proves the bypass.
        draft_params = {"to": "alice@example.com", "subject": "hi", "body": "hey"}
        outcome = request_approval(
            action_type="gmail_draft_creation",
            tool_name="create_gmail_draft",
            params=draft_params,
        )
        assert isinstance(outcome, HITLResult)
        assert outcome.rejected is False
        assert outcome.decision_type == "auto_approved"
        # Original params are preserved untouched (no user edits possible).
        assert outcome.params == {
            "to": "alice@example.com",
            "subject": "hi",
            "body": "hey",
        }

    def test_non_listed_tool_still_attempts_interrupt(self) -> None:
        # A tool NOT in the default list must reach ``langgraph.interrupt``.
        # Outside a runnable context that call raises a RuntimeError —
        # which is exactly the signal we want: the bypass did NOT fire.
        with pytest.raises(RuntimeError, match="runnable context"):
            request_approval(
                action_type="gmail_email_send",
                tool_name="send_gmail_email",
                params={"to": "alice@example.com", "subject": "hi", "body": "hey"},
            )

    def test_user_trusted_tools_still_take_precedence(self) -> None:
        # ``trusted_tools`` (per-connector "always allow" from MCP/UI)
        # was checked BEFORE the default list and must keep working
        # for tools outside the default list.
        outcome = request_approval(
            action_type="mcp_tool_call",
            tool_name="my_custom_mcp_tool",
            params={"x": 1},
            trusted_tools=["my_custom_mcp_tool"],
        )
        assert outcome.rejected is False
        assert outcome.decision_type == "trusted"

    def test_auto_approved_overrides_no_trusted_tools(self) -> None:
        # When trusted_tools is empty and tool is in the default list,
        # we should still bypass — proves the order in request_approval.
        outcome = request_approval(
            action_type="notion_page_creation",
            tool_name="create_notion_page",
            params={"title": "Plan"},
            trusted_tools=[],
        )
        assert outcome.decision_type == "auto_approved"

View file

@ -0,0 +1,333 @@
"""Cloud-mode behavior tests for the new ``rm`` and ``rmdir`` filesystem tools.
The tools build ``Command(update=...)`` payloads that the persistence
middleware applies at end of turn. These tests stub out the backend and
runtime to assert the staging payload shape:
* ``rm`` queues into ``pending_deletes`` and tombstones state files.
* ``rm`` rejects directories, ``/documents``, root, and the anonymous doc.
* ``rmdir`` queues into ``pending_dir_deletes`` and rejects non-empty dirs.
* ``rmdir`` un-stages a same-turn ``mkdir`` rather than queuing a delete.
* ``rmdir`` refuses to drop the cwd or any of its ancestors.
* ``KBPostgresBackend`` view-helpers honor staged deletes.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock
import pytest
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
pytestmark = pytest.mark.unit
def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD):
    """Build a bare ``SurfSenseFilesystemMiddleware`` without running __init__.

    Only the two attributes the rm/rmdir tool factories read are populated.
    """
    mw = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
    mw._filesystem_mode = mode
    mw._custom_tool_descriptions = {}
    return mw
def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc"):
state = state or {}
state.setdefault("cwd", "/documents")
return SimpleNamespace(state=state, tool_call_id=tool_call_id)
class _KBBackendStub(KBPostgresBackend):
    """Construct-able subclass of :class:`KBPostgresBackend` for tests.

    We bypass the real ``__init__`` (which expects a runtime + DB session)
    and inject just the methods the rm/rmdir tools touch. The class
    inheritance keeps ``isinstance(backend, KBPostgresBackend)`` checks
    inside the tools happy, which is what gates them from the desktop
    code path.
    """

    def __init__(self, *, children=None, file_data=None) -> None:
        # Directory listing stub; defaults to "no children".
        self.als_info = AsyncMock(return_value=children or [])
        # File-load stub: (data, doc_id) tuple when a file exists, else None.
        loaded = None if file_data is None else (file_data, 17)
        self._load_file_data = AsyncMock(return_value=loaded)
def _make_backend_stub(*, children=None, file_data=None) -> KBPostgresBackend:
    """Convenience wrapper around :class:`_KBBackendStub` construction."""
    stub = _KBBackendStub(children=children, file_data=file_data)
    return stub
def _bind_backend(middleware, backend):
"""Inject a backend resolver onto the middleware test instance."""
middleware._get_backend = lambda runtime: backend
return backend
# ---------------------------------------------------------------------------
# rm
# ---------------------------------------------------------------------------
class TestRmStaging:
    """``rm`` stages deletes into graph state instead of touching the DB."""

    @pytest.mark.asyncio
    async def test_stages_delete_and_tombstones_state(self):
        mw = _make_middleware()
        _bind_backend(mw, _make_backend_stub(children=[], file_data={"content": ["x"]}))
        rt = _runtime(
            {
                "cwd": "/documents",
                "files": {"/documents/notes.md": {"content": ["hello"]}},
                "doc_id_by_path": {"/documents/notes.md": 17},
            },
            tool_call_id="tc-1",
        )

        outcome = await mw._create_rm_tool().coroutine(
            "/documents/notes.md", runtime=rt
        )

        assert hasattr(outcome, "update"), f"expected Command, got {outcome!r}"
        staged = outcome.update
        assert staged["pending_deletes"] == [
            {"path": "/documents/notes.md", "tool_call_id": "tc-1"}
        ]
        # Tombstones: the file map and doc-id map both null the path out.
        assert staged["files"] == {"/documents/notes.md": None}
        assert staged["doc_id_by_path"] == {"/documents/notes.md": None}

    @pytest.mark.asyncio
    async def test_rejects_documents_root(self):
        mw = _make_middleware()
        outcome = await mw._create_rm_tool().coroutine(
            "/documents", runtime=_runtime()
        )
        assert isinstance(outcome, str)
        assert "refusing to rm" in outcome

    @pytest.mark.asyncio
    async def test_rejects_root(self):
        mw = _make_middleware()
        outcome = await mw._create_rm_tool().coroutine("/", runtime=_runtime())
        assert isinstance(outcome, str)
        assert "refusing to rm" in outcome

    @pytest.mark.asyncio
    async def test_rejects_directory_via_staged_dirs(self):
        mw = _make_middleware()
        rt = _runtime({"staged_dirs": ["/documents/team-x"]})
        outcome = await mw._create_rm_tool().coroutine(
            "/documents/team-x", runtime=rt
        )
        assert isinstance(outcome, str)
        assert "directory" in outcome.lower()
        assert "rmdir" in outcome

    @pytest.mark.asyncio
    async def test_rejects_directory_via_listing(self):
        mw = _make_middleware()
        _bind_backend(
            mw,
            _make_backend_stub(
                children=[{"path": "/documents/foo/x.md", "is_dir": False}]
            ),
        )
        outcome = await mw._create_rm_tool().coroutine(
            "/documents/foo", runtime=_runtime()
        )
        assert isinstance(outcome, str)
        assert "directory" in outcome.lower()

    @pytest.mark.asyncio
    async def test_rejects_anonymous_doc(self):
        mw = _make_middleware()
        rt = _runtime(
            {
                "kb_anon_doc": {
                    "path": "/documents/uploaded.xml",
                    "title": "uploaded",
                    "content": "",
                    "chunks": [],
                }
            }
        )
        outcome = await mw._create_rm_tool().coroutine(
            "/documents/uploaded.xml", runtime=rt
        )
        assert isinstance(outcome, str)
        assert "read-only" in outcome

    @pytest.mark.asyncio
    async def test_drops_path_from_dirty_paths(self):
        mw = _make_middleware()
        _bind_backend(mw, _make_backend_stub(children=[], file_data={"content": ["x"]}))
        rt = _runtime(
            {
                "files": {"/documents/notes.md": {"content": ["x"]}},
                "doc_id_by_path": {"/documents/notes.md": 17},
                "dirty_paths": ["/documents/notes.md"],
            }
        )
        outcome = await mw._create_rm_tool().coroutine(
            "/documents/notes.md", runtime=rt
        )
        # First element is _CLEAR sentinel; the rest must NOT contain the
        # rm'd path.
        remaining = outcome.update.get("dirty_paths") or []
        assert "/documents/notes.md" not in remaining[1:]
# ---------------------------------------------------------------------------
# rmdir
# ---------------------------------------------------------------------------
class TestRmdirStaging:
    """Staging-side behavior of the ``rmdir`` tool.

    ``rmdir`` never touches the DB directly: DB-backed empty folders are
    queued in ``pending_dir_deletes`` (applied at persistence time), while
    a folder created by ``mkdir`` earlier in the SAME turn is simply
    unstaged again.
    """

    @pytest.mark.asyncio
    async def test_stages_dir_delete_when_empty_and_db_backed(self):
        m = _make_middleware()
        backend = _bind_backend(m, _make_backend_stub(children=[]))
        # Override _load_file_data to return None (folder, not a file) and
        # parent listing to claim the folder exists.
        backend._load_file_data = AsyncMock(return_value=None)
        # side_effect order matters: first call lists the folder's own
        # children, second call lists the parent to prove the folder exists.
        backend.als_info = AsyncMock(
            side_effect=[
                [],  # children of /documents/proj
                [
                    {"path": "/documents/proj", "is_dir": True},
                ],  # parent listing
            ]
        )
        runtime = _runtime(
            {
                "cwd": "/documents",
            },
            tool_call_id="tc-rd",
        )
        tool = m._create_rmdir_tool()
        result = await tool.coroutine("/documents/proj", runtime=runtime)
        # A Command (has .update) means the delete was staged, not refused.
        assert hasattr(result, "update")
        update = result.update
        # The tool_call_id rides along so persistence can bind the eventual
        # DB mutation back to this specific action.
        assert update["pending_dir_deletes"] == [
            {"path": "/documents/proj", "tool_call_id": "tc-rd"}
        ]

    @pytest.mark.asyncio
    async def test_rejects_non_empty(self):
        m = _make_middleware()
        _bind_backend(
            m,
            _make_backend_stub(
                children=[{"path": "/documents/proj/x.md", "is_dir": False}]
            ),
        )
        runtime = _runtime()
        tool = m._create_rmdir_tool()
        result = await tool.coroutine("/documents/proj", runtime=runtime)
        # A plain string is a refusal message, not a staged Command.
        assert isinstance(result, str)
        assert "not empty" in result

    @pytest.mark.asyncio
    async def test_unstages_same_turn_mkdir(self):
        m = _make_middleware()
        _bind_backend(m, _make_backend_stub(children=[]))
        runtime = _runtime(
            {
                "cwd": "/documents",
                "staged_dirs": ["/documents/scratch"],
            },
            tool_call_id="tc-rd",
        )
        tool = m._create_rmdir_tool()
        result = await tool.coroutine("/documents/scratch", runtime=runtime)
        assert hasattr(result, "update")
        update = result.update
        # No DB row exists yet for a same-turn mkdir, so nothing is queued
        # for deletion — the staged dir is just removed again.
        assert "pending_dir_deletes" not in update
        # _CLEAR sentinel + remaining items (in this case, none).
        staged_after = update["staged_dirs"]
        assert staged_after[0] == "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00"
        assert "/documents/scratch" not in staged_after[1:]

    @pytest.mark.asyncio
    async def test_rejects_root(self):
        m = _make_middleware()
        runtime = _runtime()
        tool = m._create_rmdir_tool()
        # Both the filesystem root and the namespace root are protected.
        for victim in ("/", "/documents"):
            result = await tool.coroutine(victim, runtime=runtime)
            assert isinstance(result, str)
            assert "refusing to rmdir" in result

    @pytest.mark.asyncio
    async def test_rejects_cwd(self):
        m = _make_middleware()
        runtime = _runtime({"cwd": "/documents/proj"})
        tool = m._create_rmdir_tool()
        result = await tool.coroutine("/documents/proj", runtime=runtime)
        assert isinstance(result, str)
        assert "cwd" in result.lower()

    @pytest.mark.asyncio
    async def test_rejects_ancestor_of_cwd(self):
        m = _make_middleware()
        runtime = _runtime({"cwd": "/documents/proj/sub"})
        tool = m._create_rmdir_tool()
        # Removing an ancestor would orphan the cwd, so it is refused too.
        result = await tool.coroutine("/documents/proj", runtime=runtime)
        assert isinstance(result, str)
        assert "cwd" in result.lower()

    @pytest.mark.asyncio
    async def test_rejects_files(self):
        m = _make_middleware()
        _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]}))
        runtime = _runtime()
        tool = m._create_rmdir_tool()
        result = await tool.coroutine("/documents/notes.md", runtime=runtime)
        assert isinstance(result, str)
        assert "is a file" in result
# ---------------------------------------------------------------------------
# KBPostgresBackend view filter
# ---------------------------------------------------------------------------
class TestKBPostgresBackendDeleteFilter:
    """als_info / glob / grep should suppress paths queued for delete."""

    def _make_backend(self, state: dict[str, Any]) -> KBPostgresBackend:
        # The backend only consults ``runtime.state``; a namespace stub
        # is all it needs.
        return KBPostgresBackend(
            search_space_id=1, runtime=SimpleNamespace(state=state)
        )

    def test_pending_filesystem_view_returns_deleted_paths(self):
        state = {
            "pending_deletes": [
                {"path": "/documents/x.md", "tool_call_id": "t1"},
            ],
            "pending_dir_deletes": [
                {"path": "/documents/d1", "tool_call_id": "t2"},
            ],
        }
        backend = self._make_backend(state)
        removed, alias, deleted_dirs = backend._pending_filesystem_view({})
        assert "/documents/x.md" in removed
        assert "/documents/d1" in deleted_dirs
        assert alias == {}

    def test_dir_suppressed_covers_descendants(self):
        backend = self._make_backend({})
        doomed = {"/documents/d"}
        # The directory itself and every descendant are suppressed...
        for path in (
            "/documents/d",
            "/documents/d/x.md",
            "/documents/d/sub/y.md",
        ):
            assert backend._is_dir_suppressed(path, doomed)
        # ...but unrelated siblings are not.
        assert not backend._is_dir_suppressed("/documents/other.md", doomed)

View file

@ -98,10 +98,54 @@ class TestInitialFilesystemState:
state = _initial_filesystem_state()
assert state["cwd"] == "/documents"
assert state["staged_dirs"] == []
assert state["staged_dir_tool_calls"] == {}
assert state["pending_moves"] == []
assert state["pending_deletes"] == []
assert state["pending_dir_deletes"] == []
assert state["doc_id_by_path"] == {}
assert state["dirty_paths"] == []
assert state["dirty_path_tool_calls"] == {}
assert state["kb_priority"] == []
assert state["kb_matched_chunk_ids"] == {}
assert state["kb_anon_doc"] is None
assert state["tree_version"] == 0
class TestMultiEditSamePathCoalescing:
    """Repeated same-path edits in one turn collapse to a single binding.

    The persistence body resolves ``dirty_path_tool_calls[path]`` to find
    which tool call produced the current on-disk state. ``dirty_paths``
    dedupes through :func:`_add_unique_reducer` (no duplicate path entry
    for the second edit) while ``_dict_merge_with_tombstones_reducer``
    lets the right-hand side overwrite, so the LATEST tool_call_id wins.
    Net effect for snapshotting: revert restores the pre-mutation state
    and back-to-back edits in one turn coalesce into a single revisible
    op — the user sees ONE Revert button per turn-per-path, not N.
    """

    def test_dirty_paths_dedupes_repeated_writes(self):
        # Two writes to the same path yield a single dirty_paths entry.
        once = _add_unique_reducer([], ["/documents/a.md"])
        twice = _add_unique_reducer(once, ["/documents/a.md"])
        assert twice == ["/documents/a.md"]

    def test_dirty_path_tool_calls_keeps_latest_tool_call_id(self):
        # Right-hand side wins: the second write's tool_call_id replaces
        # the first one.
        state = _dict_merge_with_tombstones_reducer(
            {}, {"/documents/a.md": "tcid-1"}
        )
        state = _dict_merge_with_tombstones_reducer(
            state, {"/documents/a.md": "tcid-2"}
        )
        assert state == {"/documents/a.md": "tcid-2"}

    def test_rm_tombstones_dirty_path_tool_call(self):
        # ``rm`` merges ``{path: None}`` so no stale binding survives the
        # delete.
        state = _dict_merge_with_tombstones_reducer(
            {"/documents/a.md": "tcid-1"}, {"/documents/a.md": None}
        )
        assert state == {}

View file

@ -0,0 +1,83 @@
"""Smoke test for the ``134_relax_revision_fks`` Alembic migration.
A full apply/rollback test would require a live Postgres; here we verify
the migration module's static contract:
* The chain wires it as a successor of ``133_drop_documents_content_hash_unique``.
* ``upgrade()`` declares two FK creations with ``ondelete='SET NULL'``
(one for ``document_revisions.document_id``, one for
``folder_revisions.folder_id``).
* ``downgrade()`` re-establishes ``ondelete='CASCADE'`` after draining
orphaned revisions.
If any of these invariants regress, the snapshot/revert pipeline silently
loses the ability to undo ``rm`` / ``rmdir`` on environments that ran the
migration "down" or never ran it at all.
"""
from __future__ import annotations
import importlib.util
import inspect
from pathlib import Path
import pytest
pytestmark = pytest.mark.unit
_MIGRATION_PATH = (
Path(__file__).resolve().parents[3]
/ "alembic"
/ "versions"
/ "134_relax_revision_fks.py"
)
def _load_migration():
    """Load the migration module by file path (no package import needed)."""
    spec = importlib.util.spec_from_file_location("_migration_134", _MIGRATION_PATH)
    # A missing spec/loader means the file moved — fail loudly.
    assert spec and spec.loader, "could not load migration spec"
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
def test_migration_chain_revision_ids() -> None:
    """Chain wiring: revision ``134`` succeeds ``133``."""
    module = _load_migration()
    # Short numeric revision IDs match the in-tree convention; the
    # ``134_<slug>.py`` filename is documentation, not the canonical
    # revision string.
    assert getattr(module, "revision", None) == "134"
    assert getattr(module, "down_revision", None) == "133"
def test_migration_exposes_upgrade_and_downgrade() -> None:
    """Both directions must exist as callables on the module."""
    module = _load_migration()
    for required in ("upgrade", "downgrade"):
        assert callable(getattr(module, required, None)), f"{required}() is required"
def test_upgrade_creates_set_null_fks_for_both_revision_tables() -> None:
    """upgrade() must rebuild BOTH revision FKs as ON DELETE SET NULL."""
    src = inspect.getsource(_load_migration().upgrade)
    for table in ("document_revisions", "folder_revisions"):
        assert table in src
    # ON DELETE SET NULL is the entire point of the migration: snapshots
    # must outlive their parent row.
    assert src.count('ondelete="SET NULL"') >= 2
    # The ``document_id`` / ``folder_id`` columns become nullable too.
    assert "nullable=True" in src
def test_downgrade_drains_orphans_then_restores_cascade() -> None:
    """downgrade() drains orphans first, then restores CASCADE/NOT NULL."""
    src = inspect.getsource(_load_migration().downgrade)
    # Orphaned rows must be deleted BEFORE NOT NULL can be re-imposed.
    assert "DELETE FROM document_revisions WHERE document_id IS NULL" in src
    assert "DELETE FROM folder_revisions WHERE folder_id IS NULL" in src
    # Then the original CASCADE/NOT NULL contract comes back.
    assert src.count('ondelete="CASCADE"') >= 2
    assert "nullable=False" in src

View file

@ -168,6 +168,8 @@ class TestModeSpecificPrompts:
"edit_file",
"move_file",
"mkdir",
"rm",
"rmdir",
"list_tree",
"grep",
):
@ -182,6 +184,8 @@ class TestModeSpecificPrompts:
"edit_file",
"move_file",
"mkdir",
"rm",
"rmdir",
"list_tree",
"grep",
):
@ -190,6 +194,18 @@ class TestModeSpecificPrompts:
assert "/documents/" not in text, f"{name} mentions cloud namespace"
assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"
def test_cloud_descs_include_rm_and_rmdir(self):
    """Cloud-mode tool descriptions must cover both delete tools."""
    descriptions = _build_tool_descriptions(FilesystemMode.CLOUD)
    for tool_name in ("rm", "rmdir"):
        assert tool_name in descriptions
    assert "Deletes a single file" in descriptions["rm"]
    assert "Deletes an empty directory" in descriptions["rmdir"]
    assert "rmdir" in descriptions["rmdir"] and "POSIX" in descriptions["rmdir"]
def test_desktop_descs_warn_about_irreversibility(self):
    """Desktop deletes hit the real disk — both descs must warn."""
    descriptions = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER)
    for tool_name in ("rm", "rmdir"):
        assert "NOT reversible" in descriptions[tool_name]
def test_sandbox_addendum_appended_when_available(self):
prompt = _build_filesystem_system_prompt(
FilesystemMode.CLOUD, sandbox_available=True

View file

@ -0,0 +1,309 @@
"""Unit tests for the kb_persistence snapshot helpers.
The full ``commit_staged_filesystem_state`` body exercises a real session
in integration tests; here we verify the building blocks used by the
snapshot/revert pipeline:
* ``_find_action_ids_batch`` issues a SINGLE query for N tool_call_ids
(regression guard against the N+1 lookup pattern).
* ``_mark_action_reversible`` is a no-op when ``action_id`` is ``None``.
* ``_doc_revision_payload`` and ``_load_chunks_for_snapshot`` produce the
shape the snapshot helpers consume.
These tests use ``MagicMock`` / ``AsyncMock`` against a fake session so
the assertions run in milliseconds and don't require Postgres.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.agents.new_chat.middleware import kb_persistence
pytestmark = pytest.mark.unit
class _FakeResult:
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
self._rows = rows or []
self._scalar = scalar
def all(self) -> list[Any]:
return list(self._rows)
def scalar_one_or_none(self) -> Any:
return self._scalar
class _FakeSession:
def __init__(self) -> None:
self.execute = AsyncMock()
@pytest.mark.asyncio
async def test_find_action_ids_batch_issues_single_query() -> None:
    """The lookup MUST be a single ``IN (...)`` SELECT, not N selects."""
    session = _FakeSession()
    rows = [
        MagicMock(id=11, tool_call_id="tc-a"),
        MagicMock(id=22, tool_call_id="tc-b"),
        MagicMock(id=33, tool_call_id="tc-c"),
    ]
    session.execute.return_value = _FakeResult(rows=rows)
    mapping = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=1,
        tool_call_ids={"tc-a", "tc-b", "tc-c"},
    )
    assert mapping == {"tc-a": 11, "tc-b": 22, "tc-c": 33}
    assert session.execute.await_count == 1, (
        "Snapshot binding must batch into ONE query; got "
        f"{session.execute.await_count} (regression: N+1 lookup pattern)."
    )
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_thread_id_missing() -> None:
    """No thread id -> empty mapping, and the DB is never touched."""
    session = _FakeSession()
    result = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=None,
        tool_call_ids={"tc-a"},
    )
    assert result == {}
    assert session.execute.await_count == 0
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_no_calls() -> None:
    """An empty tool_call_id set must not hit the DB at all."""
    session = _FakeSession()
    result = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=42,
        tool_call_ids=set(),
    )
    assert result == {}
    assert session.execute.await_count == 0
@pytest.mark.asyncio
async def test_mark_action_reversible_is_noop_for_null_id() -> None:
    """``action_id=None`` means "nothing to flag" — no UPDATE is issued."""
    session = _FakeSession()
    await kb_persistence._mark_action_reversible(session, action_id=None)  # type: ignore[arg-type]
    assert session.execute.await_count == 0
@pytest.mark.asyncio
async def test_mark_action_reversible_runs_update_for_real_id() -> None:
    """A concrete action_id triggers exactly one UPDATE."""
    session = _FakeSession()
    await kb_persistence._mark_action_reversible(session, action_id=99)  # type: ignore[arg-type]
    assert session.execute.await_count == 1
def test_doc_revision_payload_captures_metadata_virtual_path() -> None:
    """Snapshot helpers must capture ``metadata_before`` for revert reuse."""
    doc = MagicMock()
    doc.content = "body"
    doc.title = "notes.md"
    doc.folder_id = 7
    doc.document_metadata = {"virtual_path": "/documents/team/notes.md"}
    payload = kb_persistence._doc_revision_payload(
        doc, chunks_before=[{"content": "x"}]
    )
    # Every *_before field mirrors the document's pre-write state.
    expected = {
        "title_before": "notes.md",
        "folder_id_before": 7,
        "content_before": "body",
        "chunks_before": [{"content": "x"}],
        "metadata_before": {"virtual_path": "/documents/team/notes.md"},
    }
    for key, value in expected.items():
        assert payload[key] == value
def test_doc_revision_payload_handles_missing_metadata() -> None:
    """A document without metadata snapshots ``metadata_before`` as None."""
    doc = MagicMock()
    doc.content = ""
    doc.title = ""
    doc.folder_id = None
    doc.document_metadata = None
    assert kb_persistence._doc_revision_payload(doc)["metadata_before"] is None
@pytest.mark.asyncio
async def test_load_chunks_for_snapshot_returns_content_only() -> None:
    """Snapshot chunks intentionally omit embeddings (regenerated on revert)."""
    session = _FakeSession()
    session.execute.return_value = _FakeResult(
        rows=[MagicMock(content="alpha"), MagicMock(content="beta")]
    )
    chunks = await kb_persistence._load_chunks_for_snapshot(
        session,
        doc_id=42,  # type: ignore[arg-type]
    )
    # Only the textual content is captured, in row order.
    assert chunks == [{"content": "alpha"}, {"content": "beta"}]
# ---------------------------------------------------------------------------
# Deferred reversibility-flip dispatches.
#
# The snapshot helpers used to dispatch ``action_log_updated`` directly
# from inside the SAVEPOINT block. That meant the SSE side-channel
# could tell the UI a row was reversible while the OUTER transaction
# was still pending — and if the outer commit failed, every SAVEPOINT
# rolled back too, leaving the UI in a state inconsistent with
# durable storage. The deferred-dispatch contract fixes that:
#
# • when a ``deferred_dispatches`` list is provided, the helper
# APPENDS the action_id and does NOT dispatch;
# • the caller (``commit_staged_filesystem_state``) flushes the list
# only AFTER ``await session.commit()`` succeeds; on rollback it
# clears the list so nothing is emitted.
# ---------------------------------------------------------------------------
class _NestedCtx:
"""Async context manager mimicking ``session.begin_nested()``."""
async def __aenter__(self) -> _NestedCtx:
return self
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
@pytest.mark.asyncio
async def test_pre_write_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Helpers MUST queue dispatches when ``deferred_dispatches`` is set."""
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
    session.flush = AsyncMock()

    def _assign_id(revision: Any) -> None:
        revision.id = 17

    session.add = MagicMock(side_effect=_assign_id)

    fired: list[int] = []

    async def _capture(action_id: int | None) -> None:
        if action_id is not None:
            fired.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _capture
    )

    queued: list[int] = []
    doc = MagicMock(id=99, document_metadata={"virtual_path": "/documents/x.md"})
    doc.title = "x.md"
    doc.folder_id = None
    doc.content = "body"

    revision_id = await kb_persistence._snapshot_document_pre_write(
        session,  # type: ignore[arg-type]
        doc=doc,
        action_id=42,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=queued,
    )
    assert revision_id == 17
    # Inline dispatch must NOT have fired; the action_id is queued instead.
    assert fired == []
    assert queued == [42]
@pytest.mark.asyncio
async def test_pre_write_snapshot_dispatches_inline_when_list_omitted(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Direct callers (no outer transaction) keep the legacy inline dispatch."""
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
    session.flush = AsyncMock()

    def _assign_id(revision: Any) -> None:
        revision.id = 7

    session.add = MagicMock(side_effect=_assign_id)

    fired: list[int] = []

    async def _capture(action_id: int | None) -> None:
        if action_id is not None:
            fired.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _capture
    )

    doc = MagicMock(id=11, document_metadata={"virtual_path": "/documents/y.md"})
    doc.title = "y.md"
    doc.folder_id = None
    doc.content = "body"

    await kb_persistence._snapshot_document_pre_write(
        session,  # type: ignore[arg-type]
        doc=doc,
        action_id=88,
        search_space_id=1,
        turn_id="t-1",
        # No deferred_dispatches arg — fall back to inline dispatch.
    )
    assert fired == [88]
@pytest.mark.asyncio
async def test_pre_mkdir_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Folder mkdir snapshots honour the same deferred-dispatch contract."""
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock()  # _mark_action_reversible calls execute
    session.flush = AsyncMock()

    def _assign_id(revision: Any) -> None:
        revision.id = 3

    session.add = MagicMock(side_effect=_assign_id)

    fired: list[int] = []

    async def _capture(action_id: int | None) -> None:
        if action_id is not None:
            fired.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _capture
    )

    queued: list[int] = []
    folder = MagicMock(id=2, name="f", parent_id=None, position="a0")
    await kb_persistence._snapshot_folder_pre_mkdir(
        session,  # type: ignore[arg-type]
        folder=folder,
        action_id=55,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=queued,
    )
    assert fired == []
    assert queued == [55]

View file

@ -0,0 +1,139 @@
"""Unit tests for ``KnowledgeTreeMiddleware`` rendering.
The empty-folder marker is critical UX: without it, the LLM cannot
distinguish a leaf folder containing one document from a leaf folder
that has no descendants at all, and ends up firing ``rmdir`` on
non-empty folders. These tests pin the rendering contract so that
contract cannot silently regress.
"""
from __future__ import annotations
from app.agents.new_chat.middleware.knowledge_tree import KnowledgeTreeMiddleware
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT
def _compute(folder_paths: list[str], doc_paths: list[str]) -> set[str]:
    """Shorthand for the static non-empty-folder computation under test."""
    return KnowledgeTreeMiddleware._compute_non_empty_folders(
        folder_paths, doc_paths
    )
class TestComputeNonEmptyFolders:
    """Pin the "which folders count as non-empty" computation."""

    def test_folder_with_direct_document_is_non_empty(self):
        folder = f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"
        docs = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass/southwest.pdf.xml"]
        assert folder in _compute([folder], docs)

    def test_truly_empty_leaf_folder_is_not_non_empty(self):
        folder = f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"
        assert _compute([folder], []) == set()

    def test_documents_propagate_up_to_all_ancestors(self):
        chain = [
            f"{DOCUMENTS_ROOT}/A",
            f"{DOCUMENTS_ROOT}/A/B",
            f"{DOCUMENTS_ROOT}/A/B/C",
        ]
        # A single doc at the bottom marks the whole ancestor chain.
        result = _compute(chain, [f"{DOCUMENTS_ROOT}/A/B/C/file.xml"])
        assert result == set(chain)

    def test_chain_with_subfolders_marks_only_leaf_empty(self):
        # POSIX-like semantic: a folder is "empty" only if it has no
        # immediate children (docs OR sub-folders). The model needs this
        # because parallel ``rmdir`` calls all see the same starting state,
        # so trying to rmdir a parent before its children is never safe.
        chain = [
            f"{DOCUMENTS_ROOT}/X",
            f"{DOCUMENTS_ROOT}/X/Y",
            f"{DOCUMENTS_ROOT}/X/Y/Z",
        ]
        # Only the leaf Z is empty; X and X/Y each have a sub-folder
        # child, so they must NOT carry the ``(empty)`` marker.
        assert _compute(chain, []) == {
            f"{DOCUMENTS_ROOT}/X",
            f"{DOCUMENTS_ROOT}/X/Y",
        }

    def test_sibling_with_doc_does_not_mark_other_sibling_non_empty(self):
        # Mirrors a real DB layout where every intermediate folder is
        # materialized in the ``folders`` table.
        folders = [
            f"{DOCUMENTS_ROOT}/Travel",
            f"{DOCUMENTS_ROOT}/Travel/Boarding Pass",
            f"{DOCUMENTS_ROOT}/Travel/Notes",
        ]
        result = _compute(
            folders, [f"{DOCUMENTS_ROOT}/Travel/Notes/itinerary.xml"]
        )
        # ``Travel`` has children and ``Notes`` holds the doc; the
        # sibling leaf ``Boarding Pass`` stays empty.
        assert f"{DOCUMENTS_ROOT}/Travel" in result
        assert f"{DOCUMENTS_ROOT}/Travel/Notes" in result
        assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" not in result
class TestFormatTreeRendering:
    """Integration check: empty leaf gets ``(empty)`` marker; non-empty doesn't."""

    def _render(
        self,
        folder_paths: list[str],
        doc_specs: list[dict],
    ) -> str:
        # Build a PathIndex whose folder ids are 1-based positions in
        # ``folder_paths``; ``doc_specs`` reference folders by that id.
        from app.agents.new_chat.path_resolver import PathIndex

        index = PathIndex(
            folder_paths={i + 1: p for i, p in enumerate(folder_paths)},
        )

        class _Row:
            # Attribute-bag stand-in for a document ORM row.
            def __init__(self, **kw):
                self.__dict__.update(kw)

        docs = [_Row(**spec) for spec in doc_specs]
        mw = KnowledgeTreeMiddleware(
            search_space_id=1,
            filesystem_mode=None,  # type: ignore[arg-type]
        )
        return mw._format_tree(index, docs)

    def test_renders_empty_marker_only_for_truly_empty_folders(self):
        # Reproduces the failure scenario from the bug report:
        # ``Boarding Pass`` is empty (its only doc was just deleted), while
        # ``Tax Returns`` still has ``federal.pdf``. All intermediate
        # folders are present in the index, mirroring the real DB layout.
        folder_paths = [
            "/documents/File Upload",
            "/documents/File Upload/2026-04-08",
            "/documents/File Upload/2026-04-08/Travel",
            "/documents/File Upload/2026-04-08/Travel/Boarding Pass",
            "/documents/File Upload/2026-04-15",
            "/documents/File Upload/2026-04-15/Finance",
            "/documents/File Upload/2026-04-15/Finance/Tax Returns",
        ]
        # 1-based id matching the PathIndex construction in _render.
        tax_returns_folder_id = (
            folder_paths.index("/documents/File Upload/2026-04-15/Finance/Tax Returns")
            + 1
        )
        rendered = self._render(
            folder_paths=folder_paths,
            doc_specs=[
                {
                    "id": 100,
                    "title": "federal.pdf",
                    "folder_id": tax_returns_folder_id,
                },
            ],
        )
        assert "Boarding Pass/ (empty)" in rendered
        assert "Tax Returns/ (empty)" not in rendered
        # Intermediate ancestors of the doc must NOT be marked empty.
        assert "Finance/ (empty)" not in rendered
        assert "2026-04-15/ (empty)" not in rendered
View file

@ -69,3 +69,74 @@ def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path):
assert write.error is not None
assert "parent directory" in write.error
assert not (tmp_path / "tempoo").exists()
def test_local_backend_delete_file_success(tmp_path: Path):
    """Happy path: an existing file is removed from disk."""
    backend = LocalFolderBackend(str(tmp_path))
    target = tmp_path / "delete-me.md"
    target.write_text("bye")
    res = backend.delete_file("/delete-me.md")
    assert res.error is None
    assert res.path == "/delete-me.md"
    assert not target.exists()
def test_local_backend_delete_file_rejects_directory(tmp_path: Path):
    """``delete_file`` must refuse directories (that's rmdir's job)."""
    backend = LocalFolderBackend(str(tmp_path))
    subdir = tmp_path / "subdir"
    subdir.mkdir()
    res = backend.delete_file("/subdir")
    assert res.error is not None
    assert "directory" in res.error
    # The directory must be left untouched.
    assert subdir.exists()
def test_local_backend_delete_file_missing_returns_error(tmp_path: Path):
    """Deleting a nonexistent path yields an error result, not an exception."""
    res = LocalFolderBackend(str(tmp_path)).delete_file("/nope.md")
    assert res.error is not None
    assert "not found" in res.error
def test_local_backend_rmdir_success(tmp_path: Path):
    """Happy path: an empty directory is removed from disk."""
    backend = LocalFolderBackend(str(tmp_path))
    empty = tmp_path / "empty"
    empty.mkdir()
    res = backend.rmdir("/empty")
    assert res.error is None
    assert res.path == "/empty"
    assert not empty.exists()
def test_local_backend_rmdir_rejects_non_empty(tmp_path: Path):
    """A directory with children must survive an rmdir attempt."""
    backend = LocalFolderBackend(str(tmp_path))
    child = tmp_path / "withkid" / "child.md"
    child.parent.mkdir()
    child.write_text("x")
    res = backend.rmdir("/withkid")
    assert res.error is not None
    assert "not empty" in res.error
    # Nothing may be deleted on refusal.
    assert child.exists()
def test_local_backend_rmdir_rejects_file(tmp_path: Path):
    """rmdir on a plain file is an error."""
    backend = LocalFolderBackend(str(tmp_path))
    (tmp_path / "f.md").write_text("x")
    res = backend.rmdir("/f.md")
    assert res.error is not None
    assert "not a directory" in res.error
def test_local_backend_rmdir_rejects_root(tmp_path: Path):
    """``rmdir /`` MUST fail. The exact error wording comes from
    ``_resolve_virtual`` (root resolves to outside the sandbox); what
    matters is that the call returns an error and does NOT delete the
    sandbox root on disk."""
    res = LocalFolderBackend(str(tmp_path)).rmdir("/")
    assert res.error is not None
    assert "Invalid path" in res.error or "root" in res.error
    # The sandbox root itself must still be there.
    assert tmp_path.exists()

View file

@ -0,0 +1,143 @@
"""Unit tests for the edit-from-arbitrary-position helpers inside ``new_chat_routes``.
The regenerate route's edit-from-position path introduces:
* ``_find_pre_turn_checkpoint_id`` walks LangGraph checkpoint tuples
newest-first and picks the first one whose ``metadata["turn_id"]``
differs from the edited turn. That checkpoint is the rewind target
(state immediately before the edited turn started).
* ``RegenerateRequest`` accepts ``from_message_id`` + ``revert_actions``
with a validator that prevents callers from requesting a revert pass
without specifying which turn to roll back.
These are pure-Python helpers that don't need a live DB, so we exercise
them with a small ``CheckpointTuple``-shaped namespace and direct
schema instantiation.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from app.routes.new_chat_routes import _find_pre_turn_checkpoint_id
from app.schemas.new_chat import RegenerateRequest
def _cp(checkpoint_id: str, turn_id: str | None) -> SimpleNamespace:
"""Build a fake ``CheckpointTuple`` with the metadata shape we read."""
return SimpleNamespace(
config={"configurable": {"checkpoint_id": checkpoint_id}},
metadata={"turn_id": turn_id} if turn_id is not None else {},
)
class TestFindPreTurnCheckpointId:
    """Pin the rewind-target selection for edit-from-position.

    All fixtures are ordered newest-first, matching the checkpoint tuple
    order the route receives from LangGraph.
    """

    def test_returns_last_pre_turn_checkpoint_when_editing_latest_turn(self) -> None:
        # Newest-first: T2 is the most-recent turn. The latest non-T2
        # checkpoint (cp2) is the rewind target — state immediately
        # before T2 began.
        tuples = [
            _cp("cp4", "T2"),
            _cp("cp3", "T2"),
            _cp("cp2", "T1"),
            _cp("cp1", "T1"),
        ]
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2"

    def test_returns_pre_turn_checkpoint_when_later_turns_exist(self) -> None:
        # Regression for the bug where walking newest-first returned the
        # FIRST cp with ``turn_id != target`` — which is one of the
        # later-turn checkpoints, NOT the pre-turn boundary. Editing
        # T2 must rewind to the latest T1 checkpoint (cp2), not to the
        # latest T3 checkpoint (cp6).
        tuples = [
            _cp("cp6", "T3"),
            _cp("cp5", "T3"),
            _cp("cp4", "T2"),
            _cp("cp3", "T2"),
            _cp("cp2", "T1"),
            _cp("cp1", "T1"),
        ]
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2"

    def test_returns_none_when_editing_first_turn(self) -> None:
        # No pre-turn boundary exists; caller is expected to fall back
        # to the oldest checkpoint or special-case "first turn of the
        # thread".
        tuples = [
            _cp("cp4", "T2"),
            _cp("cp3", "T2"),
            _cp("cp2", "T1"),
            _cp("cp1", "T1"),
        ]
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T1") is None

    def test_returns_none_when_only_edited_turn_present(self) -> None:
        # A thread consisting solely of the edited turn has no rewind target.
        tuples = [_cp("cp2", "T2"), _cp("cp1", "T2")]
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") is None

    def test_returns_none_for_empty_history(self) -> None:
        assert _find_pre_turn_checkpoint_id([], turn_id="T1") is None

    def test_legacy_checkpoints_without_turn_id_count_as_pre_turn(self) -> None:
        # Checkpoints written before migration 136 have no
        # ``metadata.turn_id``. They should be eligible rewind targets
        # — they came before the
        # edited turn began.
        tuples = [
            _cp("cp3", "T2"),
            SimpleNamespace(
                config={"configurable": {"checkpoint_id": "cp2"}},
                metadata=None,
            ),
            _cp("cp1", "T1"),
        ]
        # Walking oldest-first: cp1(T1) tracked, cp2(legacy/None) tracked,
        # then cp3(T2) crosses the boundary -> return cp2.
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2"

    def test_skips_checkpoint_missing_checkpoint_id_in_config(self) -> None:
        # If a checkpoint tuple's ``config["configurable"]`` is missing
        # the ``checkpoint_id`` key (corrupt / partial), we keep the
        # last known good target instead of crashing.
        broken = SimpleNamespace(
            config={"configurable": {}}, metadata={"turn_id": "T1"}
        )
        tuples = [
            _cp("cp3", "T2"),
            broken,
            _cp("cp1", "T1"),
        ]
        # cp1(T1) tracked, broken skipped, cp3(T2) -> return cp1.
        assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp1"
class TestRegenerateRequestValidation:
    """Schema-level guards on the regenerate request."""

    def test_revert_actions_requires_from_message_id(self) -> None:
        # Asking for a revert pass without naming the turn to roll back
        # must be rejected by the model validator.
        with pytest.raises(Exception) as exc:
            RegenerateRequest(
                search_space_id=1,
                user_query="hi",
                revert_actions=True,
            )
        assert "from_message_id" in str(exc.value).lower()

    def test_from_message_id_without_revert_is_allowed(self) -> None:
        request = RegenerateRequest(
            search_space_id=1,
            user_query="hi",
            from_message_id=42,
        )
        assert request.from_message_id == 42
        # revert_actions defaults to False.
        assert request.revert_actions is False

    def test_revert_actions_with_from_message_id_passes(self) -> None:
        request = RegenerateRequest(
            search_space_id=1,
            user_query="hi",
            from_message_id=42,
            revert_actions=True,
        )
        assert request.revert_actions is True

View file

@ -0,0 +1,530 @@
"""Unit tests for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
The per-turn batch revert route walks rows in reverse ``created_at``
order, reverts each independently, and returns a per-action result
list. Partial success is normal: the response status
is ``"partial"`` whenever any row could not be reverted, but we never
collapse the whole batch into a 4xx.
These tests stub ``load_thread`` / ``revert_action`` and feed a fake
session, so they exercise the route's dispatch logic without a real DB.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.routes import agent_revert_route
from app.services.revert_service import RevertOutcome
@dataclass
class _FakeAction:
    """Row stand-in for an agent-action record as the route reads it."""

    id: int
    tool_name: str
    # Owner of the action; the route compares this against the requester.
    user_id: str | None = "u1"
    # Non-None marks the row as itself a revert row (skipped by the route).
    reverse_of: int | None = None
    error: dict | None = None
@dataclass
class _FakeUser:
    """Authenticated-user stand-in; only ``id`` is consulted by the route."""

    id: str = "u1"
@dataclass
class _ScalarResult:
rows: list[Any]
def first(self) -> Any:
return self.rows[0] if self.rows else None
def all(self) -> list[Any]:
return list(self.rows)
@dataclass
class _Result:
rows: list[Any] = field(default_factory=list)
def scalars(self) -> _ScalarResult:
return _ScalarResult(self.rows)
def all(self) -> list[Any]:
# ``_was_already_reverted_batch`` calls ``.all()`` directly on
# the row-tuple result (no ``.scalars()`` indirection). The
# rows queued for that helper are list[(revert_id, original_id)].
return list(self.rows)
class _FakeNestedCtx:
    """Async context manager that mimics ``session.begin_nested()``.

    The route raises a sentinel exception inside this block to roll back
    bad rows. We just pass the exception through.
    """

    async def __aenter__(self) -> _FakeNestedCtx:
        return self

    async def __aexit__(self, exc_type, exc, tb) -> bool:
        # Returning False (or None) propagates the exception; the route
        # catches its own sentinel above this layer.
        return False
class _FakeSession:
"""Minimal AsyncSession stand-in for the revert-turn route.
Holds a queue of result objects; each ``execute(...)`` pops the next
one. The route calls ``execute`` exactly once per query so this maps
cleanly onto the assertion order of the test.
"""
def __init__(self) -> None:
self._results: list[_Result] = []
self.committed = False
self.rolled_back = False
# Count execute() calls to assert "no N+1 reverts".
self.execute_call_count = 0
def queue(self, *results: _Result) -> None:
self._results.extend(results)
async def execute(self, _stmt: Any) -> _Result:
self.execute_call_count += 1
if not self._results:
return _Result(rows=[])
return self._results.pop(0)
def begin_nested(self) -> _FakeNestedCtx:
return _FakeNestedCtx()
async def commit(self) -> None:
self.committed = True
async def rollback(self) -> None:
self.rolled_back = True
def _enabled_flags() -> AgentFeatureFlags:
    """Flag bundle with the action log and the revert route both ON."""
    return AgentFeatureFlags(
        enable_action_log=True,
        enable_revert_route=True,
        disable_new_agent_stack=False,
    )
@pytest.fixture
def patch_get_flags():
    """Factory fixture: build a patcher that pins the route's ``get_flags``.

    The returned callable takes the flag bundle to serve and patches the
    route module's ``get_flags`` lookup with it.
    """

    def _make(flags: AgentFeatureFlags):
        target = "app.routes.agent_revert_route.get_flags"
        return patch(target, return_value=flags)

    return _make
class TestFlagGuard:
    """The route must refuse to run while the revert flag is off."""

    @pytest.mark.asyncio
    async def test_returns_503_when_revert_route_disabled(
        self, patch_get_flags
    ) -> None:
        disabled = AgentFeatureFlags(
            disable_new_agent_stack=False,
            enable_action_log=True,
            enable_revert_route=False,
        )
        session = _FakeSession()
        with patch_get_flags(disabled), pytest.raises(Exception) as exc:
            await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="42:1700000000000",
                session=session,
                user=_FakeUser(),
            )
        # The HTTP exception carries the status; 503 = feature switched off.
        assert getattr(exc.value, "status_code", None) == 503
class TestRevertTurnDispatch:
    """Dispatch-loop tests for ``revert_agent_turn``.

    ``load_thread`` / ``revert_action`` are patched out and a queued
    ``_FakeSession`` feeds each SQL result in order, so these exercise
    only the route's per-row dispatch loop, its counters, and its
    batched idempotency probe — not a real database.
    """

    @pytest.mark.asyncio
    async def test_empty_turn_returns_ok_with_no_rows(self, patch_get_flags) -> None:
        session = _FakeSession()
        session.queue(_Result(rows=[]))  # rows query returns nothing
        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-empty",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.total == 0
        assert response.results == []
        # Even an empty batch commits (no-op transaction end).
        assert session.committed is True

    @pytest.mark.asyncio
    async def test_walks_rows_in_reverse_and_reverts_each(
        self, patch_get_flags
    ) -> None:
        rows = [
            _FakeAction(id=10, tool_name="rm"),
            _FakeAction(id=9, tool_name="write_file"),
            _FakeAction(id=8, tool_name="mkdir"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched ``_was_already_reverted_batch`` probe replaces
        # the previous N per-row SELECTs.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            return RevertOutcome(
                status="ok",
                message=f"reverted-{action.id}",
                new_action_id=100 + action.id,
            )

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-3",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.total == 3
        assert response.reverted == 3
        # Result order mirrors the reverse-``created_at`` walk: 10, 9, 8.
        assert [r.action_id for r in response.results] == [10, 9, 8]
        assert all(r.status == "reverted" for r in response.results)
        assert response.results[0].new_action_id == 110
        # Only TWO ``execute`` calls regardless of the row count: one
        # for the rows query, one for the batched
        # ``_was_already_reverted_batch`` probe. Regression guard
        # against re-introducing the per-row N+1 lookup.
        assert session.execute_call_count == 2, (
            "revert-turn loop must batch idempotency probes; got "
            f"{session.execute_call_count} execute() calls (expected 2)."
        )

    @pytest.mark.asyncio
    async def test_already_reverted_rows_are_marked_idempotent(
        self, patch_get_flags
    ) -> None:
        rows = [_FakeAction(id=5, tool_name="edit_file")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch probe returns ``[(revert_id, original_id)]``.
        session.queue(_Result(rows=[(42, 5)]))
        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-i",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.already_reverted == 1
        assert response.results[0].status == "already_reverted"
        # ``new_action_id`` points at the pre-existing revert row.
        assert response.results[0].new_action_id == 42
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_revert_action_skips_existing_revert_rows(
        self, patch_get_flags
    ) -> None:
        # A row whose tool_name is ``_revert:*`` is itself a revert and
        # must never be re-reverted.
        rows = [_FakeAction(id=99, tool_name="_revert:edit_file", reverse_of=42)]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-rev",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.results[0].status == "skipped"
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_partial_success_when_some_rows_not_reversible(
        self, patch_get_flags
    ) -> None:
        rows = [
            _FakeAction(id=2, tool_name="send_email"),
            _FakeAction(id=1, tool_name="edit_file"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched idempotency probe.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.tool_name == "send_email":
                return RevertOutcome(
                    status="not_reversible",
                    message="connector revert not yet implemented",
                )
            return RevertOutcome(status="ok", message="ok", new_action_id=500)

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-mix",
                session=session,
                user=_FakeUser(),
            )
        # One failure is enough to flip the batch to "partial", but the
        # good row is still reverted.
        assert response.status == "partial"
        assert response.reverted == 1
        assert response.not_reversible == 1
        statuses = sorted(r.status for r in response.results)
        assert statuses == ["not_reversible", "reverted"]

    @pytest.mark.asyncio
    async def test_unexpected_exception_marks_row_failed_not_batch(
        self, patch_get_flags
    ) -> None:
        rows = [
            _FakeAction(id=20, tool_name="edit_file"),
            _FakeAction(id=21, tool_name="edit_file"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched idempotency probe.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.id == 20:
                raise RuntimeError("disk on fire")
            return RevertOutcome(status="ok", message="ok", new_action_id=999)

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-fail",
                session=session,
                user=_FakeUser(),
            )
        # The RuntimeError is contained to its own row; the batch survives.
        assert response.status == "partial"
        assert response.failed == 1
        assert response.reverted == 1
        bad = next(r for r in response.results if r.action_id == 20)
        assert bad.status == "failed"
        assert "disk on fire" in (bad.error or "")
        good = next(r for r in response.results if r.action_id == 21)
        assert good.status == "reverted"

    @pytest.mark.asyncio
    async def test_permission_denied_when_other_user_owns_action(
        self, patch_get_flags
    ) -> None:
        rows = [_FakeAction(id=7, tool_name="edit_file", user_id="someone-else")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch idempotency probe (no prior reverts).
        session.queue(_Result(rows=[]))
        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-perm",
                session=session,
                user=_FakeUser(id="not-owner"),
            )
        assert response.status == "partial"
        assert response.results[0].status == "permission_denied"
        # ``permission_denied`` has its own dedicated counter so the
        # response invariant ``total == sum(counters)`` always holds
        # without overloading ``not_reversible`` (which historically
        # absorbed this case and confused frontend toasts).
        assert response.permission_denied == 1
        assert response.not_reversible == 0
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_counter_invariant_holds_across_mixed_outcomes(
        self, patch_get_flags
    ) -> None:
        """Every row is accounted for in EXACTLY ONE counter.

        Mixes one of every supported outcome (reverted, already_reverted,
        not_reversible, permission_denied, failed, skipped) and asserts
        that the sum of counters equals ``response.total``.
        """
        rows = [
            _FakeAction(id=10, tool_name="edit_file"),  # ok
            _FakeAction(id=9, tool_name="edit_file"),  # already_reverted
            _FakeAction(id=8, tool_name="send_email"),  # not_reversible
            _FakeAction(id=7, tool_name="rm", user_id="other"),  # permission_denied
            _FakeAction(id=6, tool_name="edit_file"),  # failed
            _FakeAction(id=5, tool_name="_revert:edit_file", reverse_of=99),  # skipped
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched probe; only id=9 has a prior revert.
        # Schema: list[(revert_id, original_id)].
        session.queue(_Result(rows=[(42, 9)]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.id == 10:
                return RevertOutcome(status="ok", message="ok", new_action_id=500)
            if action.id == 8:
                return RevertOutcome(
                    status="not_reversible",
                    message="connector revert not yet implemented",
                )
            if action.id == 6:
                raise RuntimeError("boom")
            raise AssertionError(f"unexpected revert call for {action.id}")

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route,
                "revert_action",
                AsyncMock(side_effect=_fake_revert),
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-mixed-all",
                session=session,
                user=_FakeUser(),  # only id=7 has a different user_id
            )
        assert response.total == len(rows) == 6
        bucket_sum = (
            response.reverted
            + response.already_reverted
            + response.not_reversible
            + response.permission_denied
            + response.failed
            + response.skipped
        )
        assert bucket_sum == response.total, (
            "Counter invariant broken: total "
            f"({response.total}) != sum of counters ({bucket_sum}). "
            f"Counters: reverted={response.reverted}, "
            f"already_reverted={response.already_reverted}, "
            f"not_reversible={response.not_reversible}, "
            f"permission_denied={response.permission_denied}, "
            f"failed={response.failed}, skipped={response.skipped}"
        )
        assert response.reverted == 1
        assert response.already_reverted == 1
        assert response.not_reversible == 1
        assert response.permission_denied == 1
        assert response.failed == 1
        assert response.skipped == 1

    @pytest.mark.asyncio
    async def test_integrity_error_translates_to_already_reverted(
        self, patch_get_flags
    ) -> None:
        """The partial unique index on ``reverse_of`` raises
        ``IntegrityError`` when a concurrent revert wins the race against
        the pre-flight ``_was_already_reverted`` SELECT. The route MUST
        recover by re-querying for the winning revert id and returning
        ``status="already_reverted"`` (not ``"failed"``) so racing
        clients see consistent idempotent semantics.
        """
        from sqlalchemy.exc import IntegrityError

        rows = [_FakeAction(id=33, tool_name="edit_file")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch pre-flight probe: nothing yet (we'll race).
        session.queue(_Result(rows=[]))
        # Post-IntegrityError fallback uses the SCALAR
        # ``_was_already_reverted`` (single-id lookup) so it pulls
        # ``[777]`` via ``.scalars().first()``.
        session.queue(_Result(rows=[777]))

        async def _racing_revert(_session, *, action, requester_user_id):
            raise IntegrityError("INSERT", {}, Exception("dup reverse_of"))

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route,
                "revert_action",
                AsyncMock(side_effect=_racing_revert),
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-race",
                session=session,
                user=_FakeUser(),
            )
        assert response.failed == 0, (
            "IntegrityError must NOT surface as a failed row; the unique "
            "index is the durable expression of idempotency."
        )
        assert response.already_reverted == 1
        assert response.results[0].status == "already_reverted"
        assert response.results[0].new_action_id == 777

View file

@ -0,0 +1,370 @@
"""Unit tests for the filesystem-tool branches of ``revert_service``.
Covers:
* Exact-name dispatch: ``rmdir`` does NOT mis-route to the document
branch (``"rmdir".startswith("rm")`` would mis-route under the legacy
prefix-based dispatch).
* ``rm`` revert re-INSERTs a fresh document from the snapshot, including
re-creating chunks. Falls back to ``(folder_id_before, title_before)``
when ``metadata_before["virtual_path"]`` is missing.
* ``write_file`` create-revert (``content_before IS NULL``) DELETEs the
document.
* ``rmdir`` revert re-INSERTs a fresh folder from the snapshot.
* ``mkdir`` revert DELETEs the empty folder; reports ``tool_unavailable``
when the folder gained children.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from app.services import revert_service
pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _stub_embeddings(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace ``embed_texts`` with a cheap zero-vector stub for every test."""

    def _zeros(texts):
        return [np.zeros(8, dtype=np.float32) for _ in texts]

    monkeypatch.setattr(revert_service, "embed_texts", _zeros)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeResult:
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
self._rows = rows or []
self._scalar = scalar
def all(self) -> list[Any]:
return list(self._rows)
def scalar_one_or_none(self) -> Any:
return self._scalar
def scalars(self) -> Any:
return _FakeScalarsProxy(self._rows)
class _FakeScalarsProxy:
def __init__(self, rows: list[Any]) -> None:
self._rows = rows
def first(self) -> Any:
return self._rows[0] if self._rows else None
class _FakeSession:
def __init__(self) -> None:
self.execute = AsyncMock()
self.added: list[Any] = []
self.deleted: list[Any] = []
self.flush = AsyncMock()
# session.get(Model, pk) lookup
self.get = AsyncMock(return_value=None)
async def _flush_assigning_ids() -> None:
for obj in self.added:
if getattr(obj, "id", None) is None:
obj.id = 999
self.flush.side_effect = _flush_assigning_ids
def add(self, obj: Any) -> None:
self.added.append(obj)
def add_all(self, objs: list[Any]) -> None:
self.added.extend(objs)
def _action(*, tool_name: str, action_id: int = 7):
return MagicMock(
id=action_id,
tool_name=tool_name,
thread_id=1,
search_space_id=2,
user_id="user-1",
reverse_descriptor=None,
)
def _doc_revision(
*,
document_id: int | None = None,
content_before: str | None = "old content",
title_before: str | None = "notes.md",
folder_id_before: int | None = 5,
chunks_before: list[dict[str, str]] | None = None,
metadata_before: dict[str, str] | None = None,
):
revision = MagicMock()
revision.id = 100
revision.document_id = document_id
revision.search_space_id = 2
revision.content_before = content_before
revision.title_before = title_before
revision.folder_id_before = folder_id_before
revision.chunks_before = chunks_before or []
revision.metadata_before = metadata_before
return revision
def _folder_revision(
*,
folder_id: int | None = None,
name_before: str | None = "team",
parent_id_before: int | None = None,
position_before: str | None = "a0",
):
revision = MagicMock()
revision.id = 200
revision.folder_id = folder_id
revision.search_space_id = 2
revision.name_before = name_before
revision.parent_id_before = parent_id_before
revision.position_before = position_before
return revision
# ---------------------------------------------------------------------------
# Exact-name dispatch regression guards
# ---------------------------------------------------------------------------
class TestExactDispatch:
    """Regression: ``rmdir`` MUST NOT route to the document branch."""

    @pytest.mark.asyncio
    async def test_rmdir_does_not_misroute_to_document(self) -> None:
        # If dispatch used `startswith("rm")` we'd hit the document branch
        # here. With exact-name lookup `rmdir` lands in `_FOLDER_TOOLS`.
        session = _FakeSession()
        action = _action(tool_name="rmdir")
        # No folder revisions exist for this action.
        session.execute.return_value = _FakeResult(rows=[])
        outcome = await revert_service.revert_action(
            session,  # type: ignore[arg-type]
            action=action,
            requester_user_id="user-1",
        )
        assert outcome.status == "not_reversible"
        assert "folder_revisions" in outcome.message

    def test_dispatch_sets_split_doc_and_folder(self) -> None:
        # Static guards on the dispatch tables themselves so a future
        # refactor doesn't accidentally reintroduce the prefix bug.
        assert "rm" in revert_service._DOC_TOOLS
        assert "rmdir" in revert_service._FOLDER_TOOLS
        assert "rmdir" not in revert_service._DOC_TOOLS
        assert "rm" not in revert_service._FOLDER_TOOLS
        # ``move_file`` lives only in document tools (it's a doc rename).
        assert "move_file" in revert_service._DOC_TOOLS
        assert "move_file" not in revert_service._FOLDER_TOOLS
# ---------------------------------------------------------------------------
# rm revert (re-INSERT)
# ---------------------------------------------------------------------------
class TestRmRevert:
    """``rm`` revert re-INSERTs the document (and chunks) from its snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_document_with_chunks(self) -> None:
        session = _FakeSession()
        revision = _doc_revision(
            document_id=None,  # row was hard-deleted
            content_before="hello world",
            title_before="x.md",
            folder_id_before=None,
            chunks_before=[{"content": "alpha"}, {"content": "beta"}],
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        # No collision check hit and the resulting query returns nothing.
        session.execute.return_value = _FakeResult(scalar=None)
        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )
        assert outcome.status == "ok"
        # New Document + 2 chunks must have been added.
        from app.db import Chunk, Document

        added_docs = [obj for obj in session.added if isinstance(obj, Document)]
        added_chunks = [obj for obj in session.added if isinstance(obj, Chunk)]
        assert len(added_docs) == 1
        assert added_docs[0].title == "x.md"
        assert len(added_chunks) == 2
        # Snapshot was repointed at the new doc id so a follow-up revert works.
        assert revision.document_id == added_docs[0].id

    @pytest.mark.asyncio
    async def test_falls_back_to_folder_id_and_title_for_virtual_path(
        self,
    ) -> None:
        session = _FakeSession()
        # Snapshot with NO metadata_before — the fallback path must kick in.
        revision = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="cap.md",
            folder_id_before=42,
            chunks_before=[],
            metadata_before=None,
        )
        # session.get(Folder, 42) returns a folder with a name.
        folder = MagicMock()
        folder.name = "team"
        folder.parent_id = None
        # First .get is for the folder lookup in the path-derivation.
        session.get = AsyncMock(return_value=folder)
        session.execute.return_value = _FakeResult(scalar=None)
        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )
        assert outcome.status == "ok"

    @pytest.mark.asyncio
    async def test_falls_back_to_root_path_when_no_folder(
        self,
    ) -> None:
        """metadata_before is None and folder_id_before is None still
        resolves: title fallback yields ``/documents/<title>`` so revert
        proceeds at the root of the documents tree."""
        session = _FakeSession()
        revision = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="x.md",
            folder_id_before=None,
            metadata_before=None,
        )
        # No collision in the documents tree at /documents/x.md.
        session.execute.return_value = _FakeResult(scalar=None)
        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )
        assert outcome.status == "ok"

    @pytest.mark.asyncio
    async def test_collision_with_live_doc_returns_tool_unavailable(self) -> None:
        session = _FakeSession()
        revision = _doc_revision(
            document_id=None,
            content_before="hi",
            title_before="x.md",
            folder_id_before=None,
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        # SELECT for unique_identifier_hash collision hits an existing row.
        session.execute.return_value = _FakeResult(scalar=42)
        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=revision,
        )
        assert outcome.status == "tool_unavailable"
        assert "collide" in outcome.message
# ---------------------------------------------------------------------------
# write_file create revert (DELETE)
# ---------------------------------------------------------------------------
class TestWriteFileCreateRevert:
    """Reverting a ``write_file`` that CREATED the doc means deleting it."""

    @pytest.mark.asyncio
    async def test_deletes_created_doc(self) -> None:
        session = _FakeSession()
        # ``content_before is None`` marks "created in this action".
        snapshot = _doc_revision(
            document_id=99,
            content_before=None,
            title_before=None,
        )
        outcome = await revert_service._delete_created_document(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )
        assert outcome.status == "ok"
        # Exactly one statement — the DELETE — was executed.
        assert session.execute.await_count == 1
# ---------------------------------------------------------------------------
# rmdir revert (re-INSERT folder)
# ---------------------------------------------------------------------------
class TestRmdirRevert:
    """Reverting ``rmdir`` re-creates the folder from its snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_folder_from_snapshot(self) -> None:
        from app.db import Folder

        session = _FakeSession()
        snapshot = _folder_revision(
            folder_id=None,  # original row was hard-deleted
            name_before="team",
            parent_id_before=None,
            position_before="a0",
        )
        outcome = await revert_service._reinsert_folder_from_revision(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )
        assert outcome.status == "ok"
        folders = [obj for obj in session.added if isinstance(obj, Folder)]
        assert len(folders) == 1
        assert folders[0].name == "team"
        # Snapshot repointed at the new id so a follow-up revert works.
        assert snapshot.folder_id == folders[0].id
# ---------------------------------------------------------------------------
# mkdir revert (DELETE folder)
# ---------------------------------------------------------------------------
class TestMkdirRevert:
    """Reverting ``mkdir`` deletes the folder — but only while it is empty."""

    @pytest.mark.asyncio
    async def test_deletes_empty_folder(self) -> None:
        session = _FakeSession()
        snapshot = _folder_revision(folder_id=42)
        # Doc-existence check, child-folder check, then the DELETE itself.
        session.execute.side_effect = [
            _FakeResult(scalar=None),  # docs
            _FakeResult(scalar=None),  # children
            _FakeResult(scalar=None),  # delete (no return value)
        ]
        outcome = await revert_service._delete_created_folder(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )
        assert outcome.status == "ok"
        assert session.execute.await_count == 3

    @pytest.mark.asyncio
    async def test_reports_tool_unavailable_when_folder_has_children(self) -> None:
        session = _FakeSession()
        snapshot = _folder_revision(folder_id=42)
        # The very first check (docs) finds a row -> folder is not empty.
        session.execute.return_value = _FakeResult(scalar=1)
        outcome = await revert_service._delete_created_folder(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )
        assert outcome.status == "tool_unavailable"
        assert "no longer empty" in outcome.message

View file

@ -0,0 +1,228 @@
"""Unit tests for ``stream_new_chat._extract_chunk_parts``.
Earlier versions only handled ``isinstance(chunk.content, str)`` and
silently dropped every other shape (Anthropic typed-block lists,
Bedrock reasoning blocks, ``additional_kwargs.reasoning_content`` from
a few providers). These regression tests pin those four shapes plus the
defensive cases (``None`` chunk, mixed types, missing fields).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import pytest
from app.tasks.chat.stream_new_chat import _extract_chunk_parts
@dataclass
class _FakeChunk:
    """Minimal stand-in for ``AIMessageChunk`` used in unit tests."""

    # ``content`` may be a plain string or a list of typed blocks.
    content: Any = ""
    additional_kwargs: dict[str, Any] = field(default_factory=dict)
    tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
class TestStringContent:
    """``content`` as a plain string is the simplest provider shape."""

    def test_plain_string_content_extracts_as_text(self) -> None:
        parts = _extract_chunk_parts(_FakeChunk(content="hello world"))
        assert parts["text"] == "hello world"
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []

    def test_empty_string_content_yields_empty_text(self) -> None:
        parts = _extract_chunk_parts(_FakeChunk(content=""))
        assert parts["text"] == ""
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []
class TestListContent:
    """``content`` as a list of typed blocks (Anthropic/Bedrock shape)."""

    def test_list_of_text_blocks_concatenates(self) -> None:
        chunk = _FakeChunk(
            content=[
                {"type": "text", "text": "Hello "},
                {"type": "text", "text": "world"},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "Hello world"
        assert out["reasoning"] == ""

    def test_mixed_text_and_reasoning_blocks(self) -> None:
        # Reasoning text may live under ``reasoning`` OR ``text`` keys.
        chunk = _FakeChunk(
            content=[
                {"type": "reasoning", "reasoning": "Let me think... "},
                {"type": "reasoning", "text": "still thinking."},
                {"type": "text", "text": "The answer is 42."},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "The answer is 42."
        assert out["reasoning"] == "Let me think... still thinking."

    def test_tool_call_chunks_in_content_list_extracted(self) -> None:
        chunk = _FakeChunk(
            content=[
                {"type": "text", "text": "Calling tool..."},
                {
                    "type": "tool_call_chunk",
                    "id": "call_123",
                    "name": "make_widget",
                    "args": '{"color":"red"}',
                },
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "Calling tool..."
        assert out["reasoning"] == ""
        assert len(out["tool_call_chunks"]) == 1
        assert out["tool_call_chunks"][0]["id"] == "call_123"
        assert out["tool_call_chunks"][0]["name"] == "make_widget"

    def test_tool_use_blocks_also_extracted(self) -> None:
        """Some providers (Anthropic) emit ``type='tool_use'`` instead."""
        chunk = _FakeChunk(
            content=[
                {
                    "type": "tool_use",
                    "id": "call_xyz",
                    "name": "search",
                },
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["tool_call_chunks"] == [
            {"type": "tool_use", "id": "call_xyz", "name": "search"}
        ]

    def test_unknown_block_types_are_ignored(self) -> None:
        chunk = _FakeChunk(
            content=[
                {"type": "image_url", "url": "https://example.com/x.png"},
                {"type": "text", "text": "ok"},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "ok"

    def test_blocks_without_text_field_are_ignored(self) -> None:
        chunk = _FakeChunk(
            content=[
                {"type": "text"},  # no text/content key
                {"type": "text", "text": "kept"},
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "kept"
class TestAdditionalKwargsReasoning:
    """Reasoning stashed in ``additional_kwargs.reasoning_content``."""

    def test_reasoning_content_in_additional_kwargs(self) -> None:
        # Some providers put reasoning next to, not inside, ``content``.
        parts = _extract_chunk_parts(
            _FakeChunk(
                content="visible answer",
                additional_kwargs={"reasoning_content": "internal monologue"},
            )
        )
        assert parts["text"] == "visible answer"
        assert parts["reasoning"] == "internal monologue"

    def test_reasoning_appended_to_typed_block_reasoning(self) -> None:
        # Both sources combine: typed blocks first, then the kwargs text.
        parts = _extract_chunk_parts(
            _FakeChunk(
                content=[{"type": "reasoning", "text": "from blocks. "}],
                additional_kwargs={"reasoning_content": "from kwargs."},
            )
        )
        assert parts["reasoning"] == "from blocks. from kwargs."
class TestToolCallChunksAttribute:
    """Tool-call chunks arriving via the ``tool_call_chunks`` attribute."""

    def test_tool_call_chunks_attribute_extracted_alongside_string_content(
        self,
    ) -> None:
        chunk = _FakeChunk(
            content="streaming text",
            tool_call_chunks=[
                {"name": "save_document", "args": '{"title":"x"}', "id": "tc-9"}
            ],
        )
        out = _extract_chunk_parts(chunk)
        assert out["text"] == "streaming text"
        assert len(out["tool_call_chunks"]) == 1
        assert out["tool_call_chunks"][0]["id"] == "tc-9"

    def test_attribute_and_typed_block_chunks_both_collected(self) -> None:
        chunk = _FakeChunk(
            content=[
                {
                    "type": "tool_call_chunk",
                    "id": "from-block",
                    "name": "x",
                }
            ],
            tool_call_chunks=[{"id": "from-attr", "name": "y"}],
        )
        out = _extract_chunk_parts(chunk)
        # Ordering is pinned: block-sourced first, attribute-sourced second.
        ids = [tcc.get("id") for tcc in out["tool_call_chunks"]]
        assert ids == ["from-block", "from-attr"]
class TestDefensive:
    """Malformed chunks must degrade to empty parts, never raise."""

    @pytest.mark.parametrize(
        "chunk_value",
        [None, _FakeChunk(content=None), _FakeChunk(content=42)],
    )
    def test_invalid_chunk_returns_empty_parts(self, chunk_value: Any) -> None:
        parts = _extract_chunk_parts(chunk_value)
        assert parts["text"] == ""
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []
class TestIdlessContinuationChunks:
    """Per LangChain ``ToolCallChunk`` semantics, the FIRST chunk for a
    tool call carries id+name; later chunks for the same call have
    ``id=None, name=None`` and only ``args`` + ``index``. Live tool-call
    argument streaming relies on those idless continuation chunks
    flowing through ``_extract_chunk_parts`` UNTOUCHED so the upstream
    chunk-emission loop can still route them by ``index``.
    """

    def test_idless_continuation_chunk_preserved_verbatim(self) -> None:
        chunk = _FakeChunk(
            tool_call_chunks=[
                {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
            ]
        )
        out = _extract_chunk_parts(chunk)
        assert len(out["tool_call_chunks"]) == 1
        tcc = out["tool_call_chunks"][0]
        # Every field — including the None id/name — must survive intact.
        assert tcc.get("id") is None
        assert tcc.get("name") is None
        assert tcc.get("args") == '_path":"/x"}'
        assert tcc.get("index") == 0

    def test_first_then_idless_sequence_preserves_index(self) -> None:
        """Both chunks for the same call share an ``index`` key — the
        index-routing loop in ``stream_new_chat`` depends on it."""
        first = _FakeChunk(
            tool_call_chunks=[
                {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
            ]
        )
        cont = _FakeChunk(
            tool_call_chunks=[
                {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
            ]
        )
        out_first = _extract_chunk_parts(first)
        out_cont = _extract_chunk_parts(cont)
        assert out_first["tool_call_chunks"][0]["index"] == 0
        assert out_cont["tool_call_chunks"][0]["index"] == 0
        assert out_cont["tool_call_chunks"][0].get("id") is None

View file

@ -0,0 +1,527 @@
"""Unit tests for live tool-call argument streaming.
Pins the wire format that ``_stream_agent_events`` emits when
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start``
``tool-input-delta``... ``tool-input-available`` ``tool-output-available``
all keyed by the same LangChain ``tool_call.id``.
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
``_stream_agent_events`` so we exercise them via the public wire output.
These tests also lock in the legacy / parity_v2-OFF behaviour so the
synthetic ``call_<run_id>`` shape stays stable for older clients.
"""
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Any
import pytest
import app.tasks.chat.stream_new_chat as stream_module
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
StreamResult,
_legacy_match_lc_id,
_stream_agent_events,
)
pytestmark = pytest.mark.unit
@dataclass
class _FakeChunk:
    """Minimal stand-in for ``AIMessageChunk``."""

    # Mirrors ``AIMessageChunk.content``: a plain string, or (in the typed
    # path) a list of content-block dicts — tests pass both shapes.
    content: Any = ""
    # Provider extras (e.g. ``reasoning_content``) ride along here.
    additional_kwargs: dict[str, Any] = field(default_factory=dict)
    # Parsed ToolCallChunk dicts, as LangChain attaches them to the chunk.
    tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
@dataclass
class _FakeToolMessage:
    """Stand-in for ``ToolMessage`` returned by ``on_tool_end``."""

    # Serialized tool output; ``_tool_end`` json-dumps non-string payloads.
    content: Any
    # Authoritative LangChain tool-call id; ``None`` models providers that
    # do not echo it back.
    tool_call_id: str | None = None
class _FakeAgentState:
    """Stand-in for ``StateSnapshot`` returned by ``aget_state``."""

    def __init__(self) -> None:
        # Empty values keeps the cloud-fallback safety-net branch a no-op,
        # and an empty ``tasks`` list keeps the post-stream interrupt
        # check a no-op too.
        self.values: dict[str, Any] = {}
        self.tasks: list[Any] = []
class _FakeAgent:
    """Replays a list of ``astream_events`` events."""

    def __init__(self, events: list[dict[str, Any]]) -> None:
        # Pre-baked event dicts, yielded verbatim in order.
        self._events = events

    async def astream_events(  # type: ignore[no-untyped-def]
        self, _input_data: Any, *, config: dict[str, Any], version: str
    ) -> AsyncGenerator[dict[str, Any], None]:
        del config, version  # unused, contract-compatible
        for ev in self._events:
            yield ev

    async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState:
        # Called once after astream_events drains so the cloud-fallback
        # safety net can inspect staged filesystem work. The fake stays
        # empty so the safety net is a no-op.
        return _FakeAgentState()
def _model_stream(
    *,
    text: str = "",
    reasoning: str = "",
    tool_call_chunks: list[dict[str, Any]] | None = None,
    tags: list[str] | None = None,
) -> dict[str, Any]:
    """Build one ``on_chat_model_stream`` event wrapping a fake chunk.

    ``reasoning`` piggybacks via the ``additional_kwargs`` path; a test
    that needs typed reasoning blocks can instead override ``content``
    with a block list. Most tests only check ``tool_call_chunks``
    routing, so this is fine.
    """
    # When no reasoning is supplied, an empty dict matches the dataclass
    # default, so a single construction covers both branches.
    extra_kwargs: dict[str, Any] = (
        {"reasoning_content": reasoning} if reasoning else {}
    )
    chunk = _FakeChunk(
        content=text,
        additional_kwargs=extra_kwargs,
        tool_call_chunks=list(tool_call_chunks or []),
    )
    return {
        "event": "on_chat_model_stream",
        "tags": tags or [],
        "data": {"chunk": chunk},
    }
def _tool_start(
*,
name: str,
run_id: str,
input_payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
return {
"event": "on_tool_start",
"name": name,
"run_id": run_id,
"data": {"input": input_payload or {}},
}
def _tool_end(
    *,
    name: str,
    run_id: str,
    tool_call_id: str | None = None,
    output: Any = "ok",
) -> dict[str, Any]:
    """Build an ``on_tool_end`` event whose output is a fake ``ToolMessage``.

    Non-string ``output`` values are JSON-serialized, matching how real
    tool results arrive as message content.
    """
    serialized = output if isinstance(output, str) else json.dumps(output)
    message = _FakeToolMessage(content=serialized, tool_call_id=tool_call_id)
    return {
        "event": "on_tool_end",
        "name": name,
        "run_id": run_id,
        "data": {"output": message},
    }
@pytest.fixture
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
    """Force ``get_flags`` to report stream-parity v2 as ENABLED."""
    monkeypatch.setattr(
        stream_module,
        "get_flags",
        lambda: AgentFeatureFlags(enable_stream_parity_v2=True),
    )
@pytest.fixture
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
    """Force ``get_flags`` to report stream-parity v2 as DISABLED (legacy path)."""
    monkeypatch.setattr(
        stream_module,
        "get_flags",
        lambda: AgentFeatureFlags(enable_stream_parity_v2=False),
    )
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Run ``_stream_agent_events`` against a fake agent and return the
    SSE payloads (parsed JSON) it yielded.
    """
    fake_agent = _FakeAgent(events)
    streaming_service = VercelStreamingService()
    stream_result = StreamResult()
    run_config = {"configurable": {"thread_id": "test-thread"}}

    raw_lines: list[str] = []
    async for sse_line in _stream_agent_events(
        fake_agent, run_config, {}, streaming_service, stream_result, step_prefix="thinking"
    ):
        raw_lines.append(sse_line)

    # Parse only well-formed "data: <json>" frames; skip keep-alives,
    # the [DONE] sentinel, and anything that is not valid JSON.
    prefix = "data: "
    payloads: list[dict[str, Any]] = []
    for raw in raw_lines:
        if not raw.startswith(prefix):
            continue
        body = raw[len(prefix) :].rstrip("\n")
        if body in ("", "[DONE]"):
            continue
        try:
            payloads.append(json.loads(body))
        except json.JSONDecodeError:
            continue
    return payloads
def _types(payloads: list[dict[str, Any]]) -> list[str]:
return [p.get("type", "?") for p in payloads]
def _of_type(payloads: list[dict[str, Any]], type_name: str) -> list[dict[str, Any]]:
return [p for p in payloads if p.get("type") == type_name]
# ---------------------------------------------------------------------------
# Helper: ``_legacy_match_lc_id`` is a pure refactor; assert behaviour.
# ---------------------------------------------------------------------------
class TestLegacyMatch:
    """``_legacy_match_lc_id`` is a pure refactor; pin its behaviour."""

    def test_pops_first_id_bearing_chunk_with_matching_name(self) -> None:
        pending: list[dict[str, Any]] = [
            {"id": "x1", "name": "ls"},
            {"id": "y1", "name": "write_file"},
        ]
        run_map: dict[str, str] = {}
        matched = _legacy_match_lc_id(pending, "write_file", "run-1", run_map)
        # The matching chunk is consumed and recorded against the run.
        assert matched == "y1"
        assert pending == [{"id": "x1", "name": "ls"}]
        assert run_map == {"run-1": "y1"}

    def test_falls_back_to_any_id_bearing_when_name_mismatches(self) -> None:
        pending: list[dict[str, Any]] = [{"id": "anon", "name": None}]
        run_map: dict[str, str] = {}
        assert _legacy_match_lc_id(pending, "ls", "run-2", run_map) == "anon"
        assert pending == []

    def test_returns_none_when_no_id_bearing_chunk(self) -> None:
        pending: list[dict[str, Any]] = [{"id": None, "name": None}]
        run_map: dict[str, str] = {}
        assert _legacy_match_lc_id(pending, "ls", "run-3", run_map) is None
        # Nothing consumed, nothing recorded.
        assert pending == [{"id": None, "name": None}]
        assert run_map == {}
# ---------------------------------------------------------------------------
# parity_v2 wire format tests.
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
    """First chunk carries id+name; later idless chunks at the same
    ``index`` merge into the SAME ``tool-input-start`` ui id and emit
    one ``tool-input-delta`` per chunk."""
    events = [
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
            ],
        ),
        # Continuation chunk: no id/name — must be routed purely by index.
        _model_stream(
            tool_call_chunks=[
                {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
            ],
        ),
        _tool_start(
            name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
        ),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]
    payloads = await _drain(events)
    starts = _of_type(payloads, "tool-input-start")
    deltas = _of_type(payloads, "tool-input-delta")
    available = _of_type(payloads, "tool-input-available")
    output = _of_type(payloads, "tool-output-available")
    # Exactly one card, keyed by the real LangChain tool-call id.
    assert len(starts) == 1
    assert starts[0]["toolCallId"] == "lc-1"
    assert starts[0]["toolName"] == "write_file"
    assert starts[0]["langchainToolCallId"] == "lc-1"
    # One delta per streamed chunk, in order, all on the same card.
    assert [d["inputTextDelta"] for d in deltas] == ['{"file', '_path":"/x"}']
    assert all(d["toolCallId"] == "lc-1" for d in deltas)
    assert len(available) == 1
    assert available[0]["toolCallId"] == "lc-1"
    assert len(output) == 1
    assert output[0]["toolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_two_interleaved_tool_calls_route_by_index(
    parity_v2_on: None,
) -> None:
    """Two same-name calls with distinct indices keep their deltas
    routed to the right card."""
    events = [
        # Both calls open in the same chunk, distinguished only by index.
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-A", "name": "write_file", "args": '{"a":1', "index": 0},
                {"id": "lc-B", "name": "write_file", "args": '{"b":2', "index": 1},
            ]
        ),
        # Idless continuations for both calls arrive interleaved.
        _model_stream(
            tool_call_chunks=[
                {"id": None, "name": None, "args": "}", "index": 0},
                {"id": None, "name": None, "args": "}", "index": 1},
            ]
        ),
        _tool_start(name="write_file", run_id="run-A", input_payload={"a": 1}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-A"),
        _tool_start(name="write_file", run_id="run-B", input_payload={"b": 2}),
        _tool_end(name="write_file", run_id="run-B", tool_call_id="lc-B"),
    ]
    payloads = await _drain(events)
    starts = _of_type(payloads, "tool-input-start")
    deltas = _of_type(payloads, "tool-input-delta")
    output = _of_type(payloads, "tool-output-available")
    assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"}
    # Each card accumulates only its own argument fragments, in order.
    by_id: dict[str, list[str]] = {"lc-A": [], "lc-B": []}
    for d in deltas:
        by_id[d["toolCallId"]].append(d["inputTextDelta"])
    assert by_id["lc-A"] == ['{"a":1', "}"]
    assert by_id["lc-B"] == ['{"b":2', "}"]
    assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"}
@pytest.mark.asyncio
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
    """Whatever id ``tool-input-start`` chose must be the SAME id used
    on ``tool-input-available`` AND ``tool-output-available``."""
    events = [
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-9", "name": "ls", "args": '{"path":"/"}', "index": 0}
            ]
        ),
        _tool_start(name="ls", run_id="run-X", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-X", tool_call_id="lc-9"),
    ]
    payloads = await _drain(events)
    # Collect every lifecycle frame and check they all share one id.
    relevant = [
        p
        for p in payloads
        if p.get("type")
        in {"tool-input-start", "tool-input-available", "tool-output-available"}
    ]
    assert {p["toolCallId"] for p in relevant} == {"lc-9"}
@pytest.mark.asyncio
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
    """When the chunk-emission loop already fired ``tool-input-start``
    for this run, ``on_tool_start`` MUST NOT emit a second one."""
    events = [
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
            ]
        ),
        # This on_tool_start arrives AFTER the chunk loop opened the card.
        _tool_start(name="write_file", run_id="run-A", input_payload={}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]
    payloads = await _drain(events)
    starts = _of_type(payloads, "tool-input-start")
    assert len(starts) == 1
    assert starts[0]["toolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_active_text_closes_before_early_tool_input_start(
    parity_v2_on: None,
) -> None:
    """Streaming a text-delta then a tool-call chunk in subsequent
    chunks: the wire MUST contain ``text-end`` before the FIRST
    ``tool-input-start`` (clean part boundary on the frontend)."""
    events = [
        _model_stream(text="Working on it"),
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
            ]
        ),
        _tool_start(name="write_file", run_id="run-A", input_payload={}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]
    types = _types(await _drain(events))
    # list.index returns the FIRST occurrence — exactly the boundary we pin.
    text_end_idx = types.index("text-end")
    start_idx = types.index("tool-input-start")
    assert text_end_idx < start_idx
@pytest.mark.asyncio
async def test_mixed_text_and_tool_chunk_preserve_order(
    parity_v2_on: None,
) -> None:
    """One AIMessageChunk that carries BOTH ``text`` content AND
    ``tool_call_chunks`` should emit the text delta FIRST, then close
    text, then ``tool-input-start``+``tool-input-delta``."""
    events = [
        # A single chunk carrying both channels at once.
        _model_stream(
            text="I'll update it",
            tool_call_chunks=[
                {
                    "id": "lc-1",
                    "name": "write_file",
                    "args": '{"file_path":"/x"}',
                    "index": 0,
                }
            ],
        ),
        _tool_start(
            name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
        ),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]
    types = _types(await _drain(events))
    # text-start … text-delta … text-end … tool-input-start … tool-input-delta
    assert types.index("text-start") < types.index("text-delta")
    assert types.index("text-delta") < types.index("text-end")
    assert types.index("text-end") < types.index("tool-input-start")
    assert types.index("tool-input-start") < types.index("tool-input-delta")
@pytest.mark.asyncio
async def test_parity_v2_off_preserves_legacy_shape(
    parity_v2_off: None,
) -> None:
    """When the flag is OFF, no deltas are emitted and the ``toolCallId``
    is ``call_<run_id>`` (NOT the lc id)."""
    events = [
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
            ]
        ),
        _tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
    ]
    payloads = await _drain(events)
    # Legacy mode never streams argument deltas.
    assert _of_type(payloads, "tool-input-delta") == []
    starts = _of_type(payloads, "tool-input-start")
    assert len(starts) == 1
    assert starts[0]["toolCallId"].startswith("call_run-A")
    # No ``langchainToolCallId`` propagation on ``tool-input-start`` in
    # legacy mode (the start event fires before the ToolMessage is
    # available, so we can't extract the authoritative LangChain id yet).
    assert "langchainToolCallId" not in starts[0]
    output = _of_type(payloads, "tool-output-available")
    assert output[0]["toolCallId"].startswith("call_run-A")
    # ``tool-output-available`` MUST carry ``langchainToolCallId`` even
    # in legacy mode: the chat tool card uses it to backfill the
    # LangChain id and join against the ``data-action-log`` SSE event
    # (keyed by ``lc_tool_call_id``) so the inline Revert button can
    # light up. Sourced from the returned ``ToolMessage.tool_call_id``,
    # which is populated regardless of feature-flag state.
    assert output[0]["langchainToolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_skip_append_prevents_stale_id_reuse(
    parity_v2_on: None,
) -> None:
    """Two same-name tools: the SECOND tool's ``langchainToolCallId``
    must NOT come from the first tool's chunk (``pending_tool_call_chunks``
    must stay empty for indexed-registered chunks)."""
    events = [
        # Both calls registered by index in a single chunk.
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-A", "name": "write_file", "args": "{}", "index": 0},
                {"id": "lc-B", "name": "write_file", "args": "{}", "index": 1},
            ]
        ),
        _tool_start(name="write_file", run_id="run-1", input_payload={}),
        _tool_end(name="write_file", run_id="run-1", tool_call_id="lc-A"),
        _tool_start(name="write_file", run_id="run-2", input_payload={}),
        _tool_end(name="write_file", run_id="run-2", tool_call_id="lc-B"),
    ]
    payloads = await _drain(events)
    starts = _of_type(payloads, "tool-input-start")
    # Two distinct lc ids, each its own card.
    assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"}
    # Each tool-output-available landed on its respective card.
    output = _of_type(payloads, "tool-output-available")
    assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"}
@pytest.mark.asyncio
async def test_registration_waits_for_both_id_and_name(
    parity_v2_on: None,
) -> None:
    """An id-only chunk (no name yet) must NOT emit ``tool-input-start``."""
    events = [
        # id present but name still None — registration must be deferred.
        _model_stream(
            tool_call_chunks=[{"id": "lc-1", "name": None, "args": "", "index": 0}]
        ),
    ]
    payloads = await _drain(events)
    assert _of_type(payloads, "tool-input-start") == []
@pytest.mark.asyncio
async def test_unmatched_fallback_still_attaches_lc_id(
    parity_v2_on: None,
) -> None:
    """parity_v2 ON, but the provider didn't include an ``index``: the
    legacy fallback path must still emit ``tool-input-start`` with the
    matching ``langchainToolCallId``."""
    events = [
        # No index on the chunk → not registered into index_to_meta;
        # falls through to ``pending_tool_call_chunks`` so the legacy
        # match path can pop it at on_tool_start.
        _model_stream(tool_call_chunks=[{"id": "lc-orphan", "name": "ls", "args": ""}]),
        _tool_start(name="ls", run_id="run-1", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-1", tool_call_id="lc-orphan"),
    ]
    payloads = await _drain(events)
    starts = _of_type(payloads, "tool-input-start")
    assert len(starts) == 1
    # Fallback keeps the synthetic run-scoped id but still attaches the lc id.
    assert starts[0]["toolCallId"].startswith("call_run-1")
    assert starts[0]["langchainToolCallId"] == "lc-orphan"

View file

@ -0,0 +1,66 @@
import { Skeleton } from "@/components/ui/skeleton";
/**
 * Skeleton placeholder rendered while a blog post page loads.
 *
 * Renders fixed-size <Skeleton /> blocks top-to-bottom — breadcrumb, tags,
 * title, description, author row, cover image, body paragraphs, sub-heading —
 * so the eventual content swap avoids layout shift. Presumably a Next.js
 * `loading.tsx` route file (filename not visible here — confirm).
 */
export default function BlogPostLoading() {
	return (
		<div className="min-h-screen relative pt-20">
			<div className="max-w-3xl mx-auto px-6 lg:px-10 pt-10 pb-20">
				{/* Breadcrumb */}
				<div className="flex items-center gap-2 mb-8">
					<Skeleton className="h-4 w-10" />
					<Skeleton className="h-4 w-3" />
					<Skeleton className="h-4 w-10" />
					<Skeleton className="h-4 w-3" />
					<Skeleton className="h-4 w-40" />
				</div>
				{/* Tags */}
				<div className="flex flex-wrap gap-2 mb-4">
					<Skeleton className="h-6 w-16 rounded-full" />
					<Skeleton className="h-6 w-20 rounded-full" />
				</div>
				{/* Title */}
				<div className="space-y-3 mb-6">
					<Skeleton className="h-10 w-full" />
					<Skeleton className="h-10 w-4/5" />
				</div>
				{/* Description */}
				<Skeleton className="h-5 w-full mb-2" />
				<Skeleton className="h-5 w-3/4 mb-8" />
				{/* Author + date */}
				<div className="flex items-center gap-3 mb-10">
					<Skeleton className="h-10 w-10 rounded-full" />
					<div className="space-y-1.5">
						<Skeleton className="h-4 w-32" />
						<Skeleton className="h-3 w-24" />
					</div>
				</div>
				{/* Cover image */}
				<Skeleton className="w-full aspect-video rounded-xl mb-10" />
				{/* Article body paragraphs — index keys are fine for a static list */}
				{Array.from({ length: 5 }).map((_, i) => (
					<div key={i} className="space-y-2 mb-6">
						<Skeleton className="h-4 w-full" />
						<Skeleton className="h-4 w-full" />
						<Skeleton className="h-4 w-4/5" />
					</div>
				))}
				{/* Sub-heading */}
				<Skeleton className="h-7 w-56 mt-8 mb-4" />
				{Array.from({ length: 3 }).map((_, i) => (
					<div key={i} className="space-y-2 mb-6">
						<Skeleton className="h-4 w-full" />
						<Skeleton className="h-4 w-11/12" />
						<Skeleton className="h-4 w-3/4" />
					</div>
				))}
			</div>
		</div>
	);
}

View file

@ -3,7 +3,7 @@
import { format } from "date-fns";
import FuzzySearch from "fuzzy-search";
import Link from "next/link";
import { useEffect, useMemo, useState } from "react";
import { useMemo, useState } from "react";
import { Container } from "@/components/container";
import type { BlogEntry } from "./page";
@ -127,17 +127,13 @@ function MagazineSearchGrid({
[allBlogs]
);
const [results, setResults] = useState(allBlogs);
useEffect(() => {
setResults(searcher.search(search));
}, [search, searcher]);
const gridItems = useMemo(() => {
const results = search.trim() ? searcher.search(search) : allBlogs;
if (search.trim()) {
return results;
}
return results.filter((b) => b.slug !== featuredSlug);
}, [results, search, featuredSlug]);
}, [search, searcher, allBlogs, featuredSlug]);
return (
<section aria-labelledby="archive-heading">

View file

@ -0,0 +1,50 @@
import { Skeleton } from "@/components/ui/skeleton";
/**
 * Skeleton placeholder rendered while the blog index page loads.
 *
 * Mirrors the index layout with fixed-size <Skeleton /> blocks: header,
 * featured-post hero, search bar, then a 6-card article grid. Presumably a
 * Next.js `loading.tsx` route file (filename not visible here — confirm).
 */
export default function BlogIndexLoading() {
	return (
		<div className="relative overflow-hidden bg-neutral-50 px-4 pt-20 md:px-8 dark:bg-neutral-950">
			<div className="mx-auto max-w-6xl pt-12 pb-24 md:pt-20">
				{/* Header */}
				<div className="mb-10 md:mb-14">
					<Skeleton className="h-10 w-24 rounded-md" />
				</div>
				{/* Featured post skeleton */}
				<div className="mb-14 overflow-hidden rounded-3xl border border-neutral-200/80 dark:border-neutral-800">
					<Skeleton className="aspect-[2.4/1] min-h-[220px] w-full rounded-none" />
					<div className="p-6 md:p-8 space-y-3">
						<Skeleton className="h-5 w-24 rounded-full" />
						<Skeleton className="h-8 w-3/4" />
						<Skeleton className="h-4 w-full max-w-lg" />
						<div className="flex items-center gap-3 pt-2">
							<Skeleton className="h-8 w-8 rounded-full" />
							<Skeleton className="h-4 w-28" />
							<Skeleton className="h-4 w-20" />
						</div>
					</div>
				</div>
				{/* Search bar skeleton */}
				<div className="mb-10">
					<Skeleton className="h-11 w-full max-w-md rounded-full" />
				</div>
				{/* Grid of article cards — static list, so index keys are fine */}
				<div className="grid gap-8 md:grid-cols-2 lg:grid-cols-3">
					{Array.from({ length: 6 }).map((_, i) => (
						<div key={i} className="space-y-3">
							<Skeleton className="aspect-video w-full rounded-2xl" />
							<Skeleton className="h-5 w-3/4" />
							<Skeleton className="h-4 w-full" />
							<Skeleton className="h-4 w-5/6" />
							<div className="flex items-center gap-2 pt-1">
								<Skeleton className="h-6 w-6 rounded-full" />
								<Skeleton className="h-4 w-24" />
							</div>
						</div>
					))}
				</div>
			</div>
		</div>
	);
}

View file

@ -0,0 +1,63 @@
import { Skeleton } from "@/components/ui/skeleton";
/**
 * Skeleton placeholder rendered while the changelog page loads.
 *
 * Header (breadcrumb + title + subtitle) followed by a 3-entry timeline,
 * each with a date/version column on the left and a content column
 * (title, tags, body paragraphs) on the right. Presumably a Next.js
 * `loading.tsx` route file (filename not visible here — confirm).
 */
export default function ChangelogLoading() {
	return (
		<div className="min-h-screen relative pt-20">
			{/* Header */}
			<div className="border-b border-border/50">
				<div className="max-w-5xl mx-auto relative">
					<div className="p-6 flex items-center justify-between">
						<div>
							{/* Breadcrumb */}
							<div className="flex items-center gap-2 mb-4">
								<Skeleton className="h-4 w-10" />
								<Skeleton className="h-4 w-3" />
								<Skeleton className="h-4 w-20" />
							</div>
							<Skeleton className="h-10 w-48 mb-2" />
							<Skeleton className="h-4 w-80" />
						</div>
					</div>
				</div>
			</div>
			{/* Timeline — static list, index keys are fine */}
			<div className="max-w-5xl mx-auto px-6 lg:px-10 pt-10 pb-20">
				<div className="relative">
					{Array.from({ length: 3 }).map((_, i) => (
						<div key={i} className="relative flex flex-col md:flex-row gap-y-6 mb-10">
							{/* Left: date + version */}
							<div className="md:w-48 flex-shrink-0">
								<Skeleton className="h-4 w-24 mb-3" />
								<Skeleton className="h-12 w-12 rounded-xl" />
							</div>
							{/* Right: content */}
							<div className="flex-1 md:pl-8 relative pb-10">
								<div className="space-y-4">
									{/* Title */}
									<Skeleton className="h-7 w-2/3" />
									{/* Tags */}
									<div className="flex gap-2">
										<Skeleton className="h-6 w-16 rounded-full" />
										<Skeleton className="h-6 w-20 rounded-full" />
									</div>
									{/* Body paragraphs */}
									<div className="space-y-2">
										<Skeleton className="h-4 w-full" />
										<Skeleton className="h-4 w-full" />
										<Skeleton className="h-4 w-3/4" />
									</div>
									<div className="space-y-2">
										<Skeleton className="h-4 w-full" />
										<Skeleton className="h-4 w-5/6" />
									</div>
								</div>
							</div>
						</div>
					))}
				</div>
			</div>
		</div>
	);
}

View file

@ -0,0 +1,65 @@
import { Skeleton } from "@/components/ui/skeleton";
/**
 * Skeleton placeholder rendered while a free-model chat page loads.
 *
 * Two stacked regions: a viewport-filling chat area (header, message
 * bubbles, input bar) and an SEO section below it (breadcrumb, heading,
 * FAQ cards). Presumably a Next.js `loading.tsx` route file (filename
 * not visible here — confirm).
 */
export default function FreeModelLoading() {
	return (
		<>
			{/* Chat area skeleton - fills viewport */}
			<div className="h-full flex flex-col">
				{/* Chat header */}
				<div className="flex items-center gap-3 border-b px-4 py-3">
					<Skeleton className="h-8 w-8 rounded-full" />
					<Skeleton className="h-5 w-40" />
				</div>
				{/* Chat messages area */}
				<div className="flex-1 flex flex-col justify-end gap-4 px-4 py-6">
					<div className="flex justify-end">
						<Skeleton className="h-10 w-56 rounded-2xl" />
					</div>
					<div className="space-y-2 max-w-lg">
						<Skeleton className="h-4 w-full" />
						<Skeleton className="h-4 w-4/5" />
						<Skeleton className="h-4 w-3/4" />
					</div>
				</div>
				{/* Input bar */}
				<div className="border-t px-4 py-3">
					<Skeleton className="h-12 w-full rounded-xl" />
				</div>
			</div>
			{/* SEO section skeleton */}
			<div className="border-t bg-background">
				<div className="container mx-auto px-4 py-10 max-w-3xl">
					{/* Breadcrumb */}
					<div className="flex items-center gap-2 mb-6">
						<Skeleton className="h-4 w-10" />
						<Skeleton className="h-4 w-3" />
						<Skeleton className="h-4 w-24" />
						<Skeleton className="h-4 w-3" />
						<Skeleton className="h-4 w-32" />
					</div>
					<Skeleton className="h-7 w-3/4 mb-2" />
					<Skeleton className="h-4 w-full mb-1" />
					<Skeleton className="h-4 w-2/3 mb-8" />
					<div className="my-8 h-px bg-border" />
					{/* FAQ skeleton — static list, index keys are fine */}
					<Skeleton className="h-6 w-64 mb-4" />
					<div className="flex flex-col gap-3">
						{Array.from({ length: 4 }).map((_, i) => (
							<div key={i} className="rounded-lg border bg-card p-4 space-y-2">
								<Skeleton className="h-4 w-3/4" />
								<Skeleton className="h-3 w-full" />
								<Skeleton className="h-3 w-5/6" />
							</div>
						))}
					</div>
				</div>
			</div>
		</>
	);
}

View file

@ -0,0 +1,60 @@
import { Skeleton } from "@/components/ui/skeleton";
/**
 * Skeleton placeholder rendered while the free-chat landing page loads.
 *
 * Breadcrumb, centered hero (title lines, lede, pill buttons), then an
 * 8-row model comparison table. Presumably a Next.js `loading.tsx`
 * route file (filename not visible here — confirm).
 */
export default function FreeChatLoading() {
	return (
		<div className="min-h-screen pt-20">
			<article className="container mx-auto px-4 pb-20">
				{/* Breadcrumb */}
				<div className="flex items-center gap-2 mb-8">
					<Skeleton className="h-4 w-10" />
					<Skeleton className="h-4 w-3" />
					<Skeleton className="h-4 w-24" />
				</div>
				{/* Hero section */}
				<section className="mt-8 text-center max-w-3xl mx-auto space-y-4">
					<Skeleton className="h-12 w-3/4 mx-auto" />
					<Skeleton className="h-12 w-2/3 mx-auto" />
					<Skeleton className="h-5 w-full max-w-lg mx-auto" />
					<Skeleton className="h-5 w-4/5 max-w-lg mx-auto" />
					<div className="flex flex-wrap items-center justify-center gap-3 mt-6">
						{Array.from({ length: 4 }).map((_, i) => (
							<Skeleton key={i} className="h-8 w-28 rounded-full" />
						))}
					</div>
				</section>
				<div className="my-12 max-w-4xl mx-auto h-px bg-border" />
				{/* Model table */}
				<section className="max-w-4xl mx-auto">
					<Skeleton className="h-7 w-64 mb-2" />
					<Skeleton className="h-4 w-80 mb-6" />
					<div className="overflow-hidden rounded-lg border">
						{/* Table header */}
						<div className="flex gap-4 px-4 py-3 bg-muted/50 border-b">
							<Skeleton className="h-4 w-[45%]" />
							<Skeleton className="h-4 w-24" />
							<Skeleton className="h-4 w-16" />
							<Skeleton className="h-4 w-20" />
						</div>
						{/* Table rows — static list, index keys are fine */}
						{Array.from({ length: 8 }).map((_, i) => (
							<div key={i} className="flex items-center gap-4 px-4 py-3 border-b last:border-0">
								<div className="flex-1 space-y-1.5">
									<Skeleton className="h-4 w-40" />
									<Skeleton className="h-3 w-24" />
								</div>
								<Skeleton className="h-4 w-24" />
								<Skeleton className="h-6 w-14 rounded-full" />
								<Skeleton className="h-8 w-20 rounded-md" />
							</div>
						))}
					</div>
				</section>
			</article>
		</div>
	);
}

View file

@ -0,0 +1,55 @@
import { Skeleton } from "@/components/ui/skeleton";
/**
 * Skeleton placeholder rendered while a docs page loads.
 *
 * Title, description, then alternating paragraph blocks, sub-headings,
 * a code-block placeholder, and a bulleted list — sized to match a
 * typical docs article. Presumably a Next.js `loading.tsx` route file
 * (filename not visible here — confirm).
 */
export default function DocsLoading() {
	return (
		<div className="flex flex-1 flex-col gap-4 p-6 max-w-4xl mx-auto w-full">
			{/* Title */}
			<Skeleton className="h-9 w-64" />
			{/* Description */}
			<Skeleton className="h-5 w-full max-w-md" />
			<div className="mt-4 space-y-8">
				{/* Paragraph block 1 */}
				<div className="space-y-2">
					<Skeleton className="h-4 w-full" />
					<Skeleton className="h-4 w-full" />
					<Skeleton className="h-4 w-3/4" />
				</div>
				{/* Sub-heading */}
				<Skeleton className="h-7 w-48" />
				{/* Paragraph block 2 */}
				<div className="space-y-2">
					<Skeleton className="h-4 w-full" />
					<Skeleton className="h-4 w-5/6" />
					<Skeleton className="h-4 w-full" />
					<Skeleton className="h-4 w-2/3" />
				</div>
				{/* Code block placeholder */}
				<Skeleton className="h-28 w-full rounded-lg" />
				{/* Sub-heading */}
				<Skeleton className="h-7 w-56" />
				{/* List items — static list, index keys are fine */}
				<div className="space-y-3">
					{Array.from({ length: 4 }).map((_, i) => (
						<div key={i} className="flex items-start gap-3">
							<Skeleton className="mt-1 h-3 w-3 shrink-0 rounded-full" />
							<Skeleton className="h-4 w-full max-w-lg" />
						</div>
					))}
				</div>
				{/* Paragraph block 3 */}
				<div className="space-y-2">
					<Skeleton className="h-4 w-full" />
					<Skeleton className="h-4 w-4/5" />
				</div>
			</div>
		</div>
	);
}

View file

@ -10,21 +10,11 @@ import type { Document } from "@/contracts/types/document.types";
export const mentionedDocumentsAtom = atom<Pick<Document, "id" | "title" | "document_type">[]>([]);
/**
* Atom to store documents selected via the sidebar checkboxes / row clicks.
* These are NOT inserted as chips the composer shows a count badge instead.
*/
export const sidebarSelectedDocumentsAtom = atom<
Pick<Document, "id" | "title" | "document_type">[]
>([]);
/**
* Derived read-only atom that merges @-mention chips and sidebar selections
* into a single deduplicated set of document IDs for the backend.
* Derived read-only atom that maps deduplicated mentioned docs
* into backend payload fields.
*/
export const mentionedDocumentIdsAtom = atom((get) => {
const chipDocs = get(mentionedDocumentsAtom);
const sidebarDocs = get(sidebarSelectedDocumentsAtom);
const allDocs = [...chipDocs, ...sidebarDocs];
const allDocs = get(mentionedDocumentsAtom);
const seen = new Set<string>();
const deduped = allDocs.filter((d) => {
const key = `${d.document_type}:${d.id}`;

View file

@ -17,16 +17,12 @@ import {
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Separator } from "@/components/ui/separator";
import { getToolIcon } from "@/contracts/enums/toolIcons";
import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons";
import { type AgentAction, agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
import { AppError } from "@/lib/error";
import { formatRelativeDate } from "@/lib/format-date";
import { cn } from "@/lib/utils";
function formatToolName(name: string): string {
return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
}
interface ActionLogItemProps {
action: AgentAction;
threadId: number;
@ -43,7 +39,7 @@ export function ActionLogItem({ action, threadId, onRevertSuccess }: ActionLogIt
const hasError = action.error !== null && action.error !== undefined;
const Icon = getToolIcon(action.tool_name);
const displayName = formatToolName(action.tool_name);
const displayName = getToolDisplayName(action.tool_name);
const argsPreview = action.args ? JSON.stringify(action.args, null, 2) : null;
const truncatedArgs =

View file

@ -1,9 +1,9 @@
"use client";
import { useQuery, useQueryClient } from "@tanstack/react-query";
import { useQueryClient } from "@tanstack/react-query";
import { useAtom, useAtomValue } from "jotai";
import { Activity, RefreshCcw } from "lucide-react";
import { useCallback, useMemo } from "react";
import { useCallback } from "react";
import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom";
import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
import { Badge } from "@/components/ui/badge";
@ -17,15 +17,12 @@ import {
SheetTitle,
} from "@/components/ui/sheet";
import { Skeleton } from "@/components/ui/skeleton";
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
import {
agentActionsQueryKey,
useAgentActionsQuery,
} from "@/hooks/use-agent-actions-query";
import { ActionLogItem } from "./action-log-item";
const ACTION_LOG_PAGE_SIZE = 50;
function actionLogQueryKey(threadId: number) {
return ["agent-actions", threadId] as const;
}
function EmptyState() {
return (
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-6 text-center">
@ -85,25 +82,17 @@ export function ActionLogSheet() {
const threadId = state.threadId;
const { data, isLoading, isFetching, isError, error, refetch } = useQuery({
queryKey: threadId !== null ? actionLogQueryKey(threadId) : ["agent-actions", "none"],
queryFn: () =>
agentActionsApiService.listForThread(threadId as number, {
page: 0,
pageSize: ACTION_LOG_PAGE_SIZE,
}),
enabled: state.open && threadId !== null && actionLogEnabled,
staleTime: 15 * 1000,
});
const { data, items, isLoading, isFetching, isError, error, refetch } = useAgentActionsQuery(
threadId,
{ enabled: state.open && actionLogEnabled }
);
const handleRevertSuccess = useCallback(() => {
if (threadId !== null) {
queryClient.invalidateQueries({ queryKey: actionLogQueryKey(threadId) });
queryClient.invalidateQueries({ queryKey: agentActionsQueryKey(threadId) });
}
}, [queryClient, threadId]);
const items = useMemo(() => data?.items ?? [], [data]);
return (
<Sheet open={state.open} onOpenChange={(open) => setState((s) => ({ ...s, open }))}>
<SheetContent

View file

@ -33,6 +33,8 @@ import {
useAllCitationMetadata,
} from "@/components/assistant-ui/citation-metadata-context";
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part";
import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button";
import { useTokenUsage } from "@/components/assistant-ui/token-usage-context";
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
@ -491,6 +493,7 @@ const AssistantMessageInner: FC = () => {
<MessagePrimitive.Parts
components={{
Text: MarkdownText,
Reasoning: ReasoningMessagePart,
tools: {
by_name: {
generate_report: GenerateReportToolUI,
@ -699,6 +702,13 @@ const AssistantActionBar: FC = () => {
const isLast = useAuiState((s) => s.message.isLast);
const aui = useAui();
const api = useElectronAPI();
// Surface the persisted ``chat_turn_id`` so the per-turn revert
// affordance can scope to just this message's actions. Streamed
// turns get their id once the assistant message is hydrated/finalised.
const chatTurnId = useAuiState(({ message }) => {
const meta = message?.metadata as { custom?: { chatTurnId?: string | null } } | undefined;
return meta?.custom?.chatTurnId ?? null;
});
const isQuickAssist = !!api?.replaceText && IS_QUICK_ASSIST_WINDOW;
@ -743,6 +753,9 @@ const AssistantActionBar: FC = () => {
</TooltipIconButton>
)}
<MessageInfoDropdown />
<div className="ml-auto">
<RevertTurnButton chatTurnId={chatTurnId} />
</div>
</ActionBarPrimitive.Root>
);
};

View file

@ -0,0 +1,106 @@
"use client";
/**
* Confirmation dialog shown when the user edits a message that has
* reversible downstream actions. Three buttons:
*
* "Revert all & resubmit" POST regenerate with revert_actions=true
* "Continue without revert" POST regenerate with revert_actions=false
* "Cancel" abort the edit entirely
*
* The dialog is auto-skipped when zero reversible downstream actions
* exist (the caller checks first via ``downstreamReversibleCount``).
*/
import { useEffect, useRef, useState } from "react";
import {
AlertDialog,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import { Button } from "@/components/ui/button";
export type EditMessageDialogChoice = "revert" | "continue" | "cancel";
export interface EditMessageDialogProps {
open: boolean;
onOpenChange: (open: boolean) => void;
downstreamReversibleCount: number;
downstreamTotalCount: number;
onChoose: (choice: EditMessageDialogChoice) => void | Promise<void>;
}
export function EditMessageDialog({
	open,
	onOpenChange,
	downstreamReversibleCount,
	downstreamTotalCount,
	onChoose,
}: EditMessageDialogProps) {
	// The choice currently awaiting `onChoose`, or null while idle.
	// Non-null disables every button and swaps in its in-progress label.
	const [pending, setPending] = useState<EditMessageDialogChoice | null>(null);

	// The parent clears its dialog state before awaiting the (possibly
	// long-running) regenerate call, which unmounts this dialog while
	// `onChoose` is still pending. Guarding on this ref keeps the
	// `setPending(null)` in the handler from firing after unmount and
	// raising React's "state update on unmounted component" dev warning.
	const mountedRef = useRef(true);
	useEffect(() => {
		mountedRef.current = true;
		return () => {
			mountedRef.current = false;
		};
	}, []);

	// Forward the user's pick; always clear the busy marker afterwards
	// (but only if we are still mounted — see mountedRef above).
	const choose = async (choice: EditMessageDialogChoice) => {
		setPending(choice);
		try {
			await onChoose(choice);
		} finally {
			if (mountedRef.current) setPending(null);
		}
	};

	return (
		<AlertDialog open={open} onOpenChange={onOpenChange}>
			<AlertDialogContent>
				<AlertDialogHeader>
					<AlertDialogTitle>Edit this message?</AlertDialogTitle>
					<AlertDialogDescription>
						This edit drops {downstreamTotalCount} downstream message
						{downstreamTotalCount === 1 ? "" : "s"} from the thread. {downstreamReversibleCount}{" "}
						action
						{downstreamReversibleCount === 1 ? "" : "s"} (e.g. file writes, connector changes) can
						be rolled back. Pick how to handle them before regenerating.
					</AlertDialogDescription>
				</AlertDialogHeader>
				<div className="grid gap-2">
					<Button variant="default" disabled={pending !== null} onClick={() => choose("revert")}>
						{pending === "revert"
							? "Reverting & resubmitting…"
							: `Revert ${downstreamReversibleCount} action${
									downstreamReversibleCount === 1 ? "" : "s"
								} & resubmit`}
					</Button>
					<Button variant="outline" disabled={pending !== null} onClick={() => choose("continue")}>
						{pending === "continue" ? "Resubmitting…" : "Continue without reverting"}
					</Button>
				</div>
				<AlertDialogFooter className="sm:justify-start">
					<AlertDialogCancel disabled={pending !== null} onClick={() => choose("cancel")}>
						Cancel
					</AlertDialogCancel>
				</AlertDialogFooter>
			</AlertDialogContent>
		</AlertDialog>
	);
}

View file

@ -11,25 +11,14 @@ import {
useRef,
useState,
} from "react";
import { flushSync } from "react-dom";
import { createRoot } from "react-dom/client";
import { renderToStaticMarkup } from "react-dom/server";
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import type { Document } from "@/contracts/types/document.types";
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
import { cn } from "@/lib/utils";
// Render a React element to an HTML string on the client without pulling
// `react-dom/server` into the bundle. `createRoot` + `flushSync` use the
// same `react-dom` package React itself imports, so this adds zero new
// runtime weight.
function renderElementToHTML(element: ReactElement): string {
const container = document.createElement("div");
const root = createRoot(container);
flushSync(() => {
root.render(element);
});
const html = container.innerHTML;
root.unmount();
return html;
return renderToStaticMarkup(element);
}
export interface MentionedDocument {
@ -44,7 +33,10 @@ export interface InlineMentionEditorRef {
setText: (text: string) => void;
getText: () => string;
getMentionedDocuments: () => MentionedDocument[];
insertDocumentChip: (doc: Pick<Document, "id" | "title" | "document_type">) => void;
insertDocumentChip: (
doc: Pick<Document, "id" | "title" | "document_type">,
options?: { removeTriggerText?: boolean }
) => void;
removeDocumentChip: (docId: number, docType?: string) => void;
setDocumentChipStatus: (
docId: number,
@ -66,7 +58,6 @@ interface InlineMentionEditorProps {
onKeyDown?: (e: React.KeyboardEvent) => void;
disabled?: boolean;
className?: string;
initialDocuments?: MentionedDocument[];
initialText?: string;
}
@ -118,7 +109,6 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
onKeyDown,
disabled = false,
className,
initialDocuments = [],
initialText,
},
ref
@ -126,18 +116,49 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
const editorRef = useRef<HTMLDivElement>(null);
const [isEmpty, setIsEmpty] = useState(true);
const [mentionedDocs, setMentionedDocs] = useState<Map<string, MentionedDocument>>(
() => new Map(initialDocuments.map((d) => [`${d.document_type ?? "UNKNOWN"}:${d.id}`, d]))
() => new Map()
);
const isComposingRef = useRef(false);
const lastSelectionRangeRef = useRef<Range | null>(null);
const isRangeInsideEditor = useCallback((range: Range | null): range is Range => {
if (!range || !editorRef.current) return false;
return (
editorRef.current.contains(range.startContainer) &&
editorRef.current.contains(range.endContainer)
);
}, []);
const isSelectionInsideEditor = useCallback(
(selection: Selection | null): selection is Selection => {
if (!selection || selection.rangeCount === 0 || !editorRef.current) return false;
const range = selection.getRangeAt(0);
return isRangeInsideEditor(range);
},
[isRangeInsideEditor]
);
const rememberSelection = useCallback(() => {
const selection = window.getSelection();
if (!isSelectionInsideEditor(selection)) return;
lastSelectionRangeRef.current = selection.getRangeAt(0).cloneRange();
}, [isSelectionInsideEditor]);
const restoreRememberedSelection = useCallback((): Selection | null => {
const selection = window.getSelection();
if (!selection) return null;
if (!isRangeInsideEditor(lastSelectionRangeRef.current)) return null;
selection.removeAllRanges();
selection.addRange(lastSelectionRangeRef.current.cloneRange());
return selection;
}, [isRangeInsideEditor]);
// Sync initial documents
useEffect(() => {
if (initialDocuments.length > 0) {
setMentionedDocs(
new Map(initialDocuments.map((d) => [`${d.document_type ?? "UNKNOWN"}:${d.id}`, d]))
);
}
}, [initialDocuments]);
const handleSelectionChange = () => {
if (document.activeElement !== editorRef.current) return;
rememberSelection();
};
document.addEventListener("selectionchange", handleSelectionChange);
return () => document.removeEventListener("selectionchange", handleSelectionChange);
}, [rememberSelection]);
useEffect(() => {
if (!initialText || !editorRef.current) return;
@ -145,7 +166,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
editorRef.current.appendChild(document.createElement("br"));
editorRef.current.appendChild(document.createElement("br"));
setIsEmpty(false);
onChange?.(initialText, Array.from(mentionedDocs.values()));
onChange?.(initialText, []);
editorRef.current.focus();
const sel = window.getSelection();
const range = document.createRange();
@ -157,7 +178,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
range.insertNode(anchor);
anchor.scrollIntoView({ block: "end" });
anchor.remove();
}, [initialText]); // eslint-disable-line react-hooks/exhaustive-deps
}, [initialText, onChange]);
// Focus at the end of the editor
const focusAtEnd = useCallback(() => {
@ -211,6 +232,19 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
return Array.from(mentionedDocs.values());
}, [mentionedDocs]);
const syncEditorState = useCallback(
(docsOverride?: Map<string, MentionedDocument>) => {
const docs = docsOverride
? Array.from(docsOverride.values())
: Array.from(mentionedDocs.values());
const text = getText();
const empty = text.length === 0 && docs.length === 0;
setIsEmpty(empty);
onChange?.(text, docs);
},
[getText, mentionedDocs, onChange]
);
// Create a chip element for a document
const createChipElement = useCallback(
(doc: MentionedDocument): HTMLSpanElement => {
@ -246,10 +280,11 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
e.preventDefault();
e.stopPropagation();
chip.remove();
const docKey = `${doc.document_type ?? "UNKNOWN"}:${doc.id}`;
const docKey = getMentionDocKey(doc);
setMentionedDocs((prev) => {
const next = new Map(prev);
next.delete(docKey);
syncEditorState(next);
return next;
});
onDocumentRemove?.(doc.id, doc.document_type);
@ -294,13 +329,17 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
return chip;
},
[focusAtEnd, onDocumentRemove]
[focusAtEnd, onDocumentRemove, syncEditorState]
);
// Insert a document chip at the current cursor position
const insertDocumentChip = useCallback(
(doc: Pick<Document, "id" | "title" | "document_type">) => {
(
doc: Pick<Document, "id" | "title" | "document_type">,
options?: { removeTriggerText?: boolean }
) => {
if (!editorRef.current) return;
const removeTriggerText = options?.removeTriggerText ?? true;
// Validate required fields for type safety
if (typeof doc.id !== "number" || typeof doc.title !== "string") {
@ -315,25 +354,51 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
};
// Add to mentioned docs map using unique key
const docKey = `${doc.document_type ?? "UNKNOWN"}:${doc.id}`;
const docKey = getMentionDocKey(doc);
setMentionedDocs((prev) => new Map(prev).set(docKey, mentionDoc));
const nextDocs = new Map(mentionedDocs);
nextDocs.set(docKey, mentionDoc);
// Find and remove the @query text
const selection = window.getSelection();
if (!selection || selection.rangeCount === 0) {
// No selection, just append
const hasActiveSelection = isSelectionInsideEditor(selection);
const resolvedSelection = hasActiveSelection ? selection : restoreRememberedSelection();
if (
!resolvedSelection ||
resolvedSelection.rangeCount === 0 ||
!isSelectionInsideEditor(resolvedSelection)
) {
// No valid in-editor selection: deterministically insert at end.
editorRef.current.focus();
const endSelection = window.getSelection();
if (!endSelection) return;
const endRange = document.createRange();
endRange.selectNodeContents(editorRef.current);
endRange.collapse(false);
endSelection.removeAllRanges();
endSelection.addRange(endRange);
const chip = createChipElement(mentionDoc);
editorRef.current.appendChild(chip);
editorRef.current.appendChild(document.createTextNode(" "));
focusAtEnd();
endRange.insertNode(chip);
endRange.setStartAfter(chip);
endRange.collapse(true);
const space = document.createTextNode(" ");
endRange.insertNode(space);
endRange.setStartAfter(space);
endRange.collapse(true);
endSelection.removeAllRanges();
endSelection.addRange(endRange);
syncEditorState(nextDocs);
rememberSelection();
return;
}
// Find the @ symbol before the cursor and remove it along with any query text
const range = selection.getRangeAt(0);
const range = resolvedSelection.getRangeAt(0);
const textNode = range.startContainer;
if (textNode.nodeType === Node.TEXT_NODE) {
if (textNode.nodeType === Node.TEXT_NODE && removeTriggerText) {
const text = textNode.textContent || "";
const cursorPos = range.startOffset;
@ -369,8 +434,9 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
const newRange = document.createRange();
newRange.setStart(afterNode, 1);
newRange.collapse(true);
selection.removeAllRanges();
selection.addRange(newRange);
resolvedSelection.removeAllRanges();
resolvedSelection.addRange(newRange);
rememberSelection();
}
} else {
// No @ found, just insert at cursor
@ -384,48 +450,56 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
range.insertNode(space);
range.setStartAfter(space);
range.collapse(true);
resolvedSelection.removeAllRanges();
resolvedSelection.addRange(range);
rememberSelection();
}
} else {
// Not in a text node, append to editor
// Either explicit non-trigger insertion or no @query present.
const chip = createChipElement(mentionDoc);
editorRef.current.appendChild(chip);
editorRef.current.appendChild(document.createTextNode(" "));
focusAtEnd();
range.insertNode(chip);
range.setStartAfter(chip);
range.collapse(true);
const space = document.createTextNode(" ");
range.insertNode(space);
range.setStartAfter(space);
range.collapse(true);
resolvedSelection.removeAllRanges();
resolvedSelection.addRange(range);
rememberSelection();
}
// Update empty state
setIsEmpty(false);
// Trigger onChange
if (onChange) {
setTimeout(() => {
onChange(getText(), getMentionedDocuments());
}, 0);
}
syncEditorState(nextDocs);
},
[createChipElement, focusAtEnd, getText, getMentionedDocuments, onChange]
[
createChipElement,
isSelectionInsideEditor,
mentionedDocs,
rememberSelection,
restoreRememberedSelection,
syncEditorState,
]
);
// Clear the editor
const clear = useCallback(() => {
if (editorRef.current) {
editorRef.current.innerHTML = "";
setIsEmpty(true);
setMentionedDocs(new Map());
const emptyDocs = new Map<string, MentionedDocument>();
setMentionedDocs(emptyDocs);
syncEditorState(emptyDocs);
}
}, []);
}, [syncEditorState]);
// Replace editor content with plain text and place cursor at end
const setText = useCallback(
(text: string) => {
if (!editorRef.current) return;
editorRef.current.innerText = text;
const empty = text.length === 0;
setIsEmpty(empty);
onChange?.(text, Array.from(mentionedDocs.values()));
syncEditorState();
focusAtEnd();
},
[focusAtEnd, onChange, mentionedDocs]
[focusAtEnd, syncEditorState]
);
const setDocumentChipStatus = useCallback(
@ -473,7 +547,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
const removeDocumentChip = useCallback(
(docId: number, docType?: string) => {
if (!editorRef.current) return;
const chipKey = `${docType ?? "UNKNOWN"}:${docId}`;
const chipKey = getMentionDocKey({ id: docId, document_type: docType });
const chips = editorRef.current.querySelectorAll<HTMLSpanElement>(
`span[${CHIP_DATA_ATTR}="true"]`
);
@ -486,14 +560,11 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
setMentionedDocs((prev) => {
const next = new Map(prev);
next.delete(chipKey);
syncEditorState(next);
return next;
});
const text = getText();
const empty = text.length === 0 && mentionedDocs.size <= 1;
setIsEmpty(empty);
},
[getText, mentionedDocs.size]
[syncEditorState]
);
// Expose methods via ref
@ -594,6 +665,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
// Notify parent of change
onChange?.(text, Array.from(mentionedDocs.values()));
rememberSelection();
}, [
getText,
mentionedDocs,
@ -602,6 +674,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
onMentionClose,
onActionTrigger,
onActionClose,
rememberSelection,
]);
// Handle keydown
@ -639,10 +712,14 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
const chipDocType = getChipDocType(prevSibling);
if (chipId !== null) {
prevSibling.remove();
const chipKey = `${chipDocType}:${chipId}`;
const chipKey = getMentionDocKey({
id: chipId,
document_type: chipDocType,
});
setMentionedDocs((prev) => {
const next = new Map(prev);
next.delete(chipKey);
syncEditorState(next);
return next;
});
// Notify parent that a document was removed
@ -676,10 +753,14 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
const chipDocType = getChipDocType(prevChild);
if (chipId !== null) {
prevChild.remove();
const chipKey = `${chipDocType}:${chipId}`;
const chipKey = getMentionDocKey({
id: chipId,
document_type: chipDocType,
});
setMentionedDocs((prev) => {
const next = new Map(prev);
next.delete(chipKey);
syncEditorState(next);
return next;
});
// Notify parent that a document was removed
@ -691,7 +772,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
}
}
},
[onKeyDown, onSubmit, onDocumentRemove, onMentionClose]
[onKeyDown, onSubmit, onDocumentRemove, onMentionClose, syncEditorState]
);
// Handle paste - strip formatting
@ -713,7 +794,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
return (
<div className="relative w-full">
{/** biome-ignore lint/a11y/useSemanticElements: <not important> */}
{/* biome-ignore lint/a11y/noStaticElementInteractions: contenteditable mention editor requires a div for inline chips */}
<div
ref={editorRef}
contentEditable={!disabled}
@ -724,6 +805,9 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
onPaste={handlePaste}
onCompositionStart={handleCompositionStart}
onCompositionEnd={handleCompositionEnd}
onKeyUp={rememberSelection}
onMouseUp={rememberSelection}
onBlur={rememberSelection}
className={cn(
"min-h-[24px] max-h-32 overflow-y-auto",
"text-sm outline-none",
@ -733,9 +817,6 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
)}
style={{ wordBreak: "break-word" }}
data-placeholder={placeholder}
aria-label="Message input with inline mentions"
role="textbox"
aria-multiline="true"
/>
{/* Placeholder with fade animation on change */}
{isEmpty && (

View file

@ -0,0 +1,81 @@
"use client";
import type { ReasoningMessagePartComponent } from "@assistant-ui/react";
import { ChevronRightIcon } from "lucide-react";
import { useEffect, useMemo, useState } from "react";
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
import { cn } from "@/lib/utils";
/**
* Renders the structured `reasoning` part emitted by the backend's
* stream-parity v2 path (A1).
*
* Behaviour mirrors the existing `ThinkingStepsDisplay`:
* - collapsed by default;
* - auto-expanded while the part is still `running`;
* - auto-collapsed once status flips to `complete`.
*
* The component is registered via the `Reasoning` slot on
* `MessagePrimitive.Parts` in `assistant-message.tsx` so it lives at the
* exact ordinal position of the reasoning block in the message content
* array (i.e. above the assistant text that follows it).
*/
export const ReasoningMessagePart: ReasoningMessagePartComponent = ({ text, status }) => {
	const isRunning = status?.type === "running";

	// Open while streaming; collapse again once the part completes.
	const [isOpen, setIsOpen] = useState(() => isRunning);
	useEffect(() => {
		if (isRunning) setIsOpen(true);
		else if (status?.type === "complete") setIsOpen(false);
	}, [isRunning, status?.type]);

	const headerLabel = useMemo(() => {
		if (isRunning) return "Thinking";
		return status?.type === "incomplete" ? "Thinking interrupted" : "Thought";
	}, [isRunning, status?.type]);

	// Nothing to show for an empty, finished reasoning block. While still
	// running we render the shimmer header even before any text arrives.
	if (!text && !isRunning) return null;

	return (
		<div className="mx-auto w-full max-w-(--thread-max-width) px-2 py-2">
			<div className="rounded-lg">
				<button
					type="button"
					onClick={() => setIsOpen((open) => !open)}
					className={cn(
						"flex w-full items-center gap-1.5 text-left text-sm transition-colors",
						"text-muted-foreground hover:text-foreground"
					)}
				>
					{isRunning ? (
						<TextShimmerLoader text={headerLabel} size="sm" />
					) : (
						<span>{headerLabel}</span>
					)}
					<ChevronRightIcon
						className={cn("size-4 transition-transform duration-200", isOpen && "rotate-90")}
					/>
				</button>
				<div
					className={cn(
						"grid transition-[grid-template-rows] duration-300 ease-out",
						isOpen ? "grid-rows-[1fr]" : "grid-rows-[0fr]"
					)}
				>
					<div className="overflow-hidden">
						<div className="mt-2 border-l border-muted-foreground/30 pl-3 text-sm leading-relaxed text-muted-foreground whitespace-pre-wrap wrap-break-word">
							{text}
						</div>
					</div>
				</div>
			</div>
		</div>
	);
};

View file

@ -0,0 +1,213 @@
"use client";
/**
* "Revert turn" button rendered at the bottom of every completed
* assistant turn that has at least one reversible action.
*
* The button reads from the unified ``useAgentActionsQuery`` cache
* (the SAME react-query cache the agent-actions sheet and the inline
* Revert button consume) filtered by ``chat_turn_id``. It shows a
* confirmation dialog summarising "N reversible / M total" and, on
* confirm, calls ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
*
* The route returns a per-action result list and never collapses the
* batch into a 4xx so we render any failed/not_reversible rows inline
* with their messages.
*/
import { useQueryClient } from "@tanstack/react-query";
import { useAtomValue } from "jotai";
import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react";
import { useMemo, useState } from "react";
import { toast } from "sonner";
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
AlertDialogTrigger,
} from "@/components/ui/alert-dialog";
import { Button } from "@/components/ui/button";
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
import {
applyRevertTurnResultsToCache,
useAgentActionsQuery,
} from "@/hooks/use-agent-actions-query";
import {
agentActionsApiService,
type RevertTurnActionResult,
} from "@/lib/apis/agent-actions-api.service";
import { AppError } from "@/lib/error";
import { cn } from "@/lib/utils";
interface RevertTurnButtonProps {
chatTurnId: string | null | undefined;
}
export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) {
const session = useAtomValue(chatSessionStateAtom);
const threadId = session?.threadId ?? null;
const queryClient = useQueryClient();
const { findByChatTurnId } = useAgentActionsQuery(threadId);
const [isReverting, setIsReverting] = useState(false);
const [confirmOpen, setConfirmOpen] = useState(false);
const [resultsOpen, setResultsOpen] = useState(false);
const [results, setResults] = useState<RevertTurnActionResult[]>([]);
const actions = useMemo(() => findByChatTurnId(chatTurnId), [findByChatTurnId, chatTurnId]);
const reversibleCount = useMemo(
() =>
actions.filter(
(a) =>
a.reversible &&
(a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) &&
!a.is_revert_action &&
(a.error === null || a.error === undefined)
).length,
[actions]
);
const totalCount = useMemo(() => actions.filter((a) => !a.is_revert_action).length, [actions]);
if (!chatTurnId) return null;
if (reversibleCount === 0) return null;
if (!threadId) return null;
const handleRevertTurn = async () => {
setIsReverting(true);
try {
const response = await agentActionsApiService.revertTurn(threadId, chatTurnId);
setResults(response.results);
const revertedEntries = response.results
.filter((r) => r.status === "reverted" || r.status === "already_reverted")
.map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null }));
if (revertedEntries.length > 0) {
applyRevertTurnResultsToCache(queryClient, threadId, revertedEntries);
}
if (response.status === "ok") {
toast.success(
response.reverted === 1 ? "Reverted 1 action." : `Reverted ${response.reverted} actions.`
);
} else {
// Every "not undone" bucket counts as a failure for the
// user-facing summary. ``skipped`` rows are batch
// artefacts (revert rows themselves) and intentionally
// excluded from the failure tally.
const failureCount =
response.failed + response.not_reversible + (response.permission_denied ?? 0);
toast.warning(
`Reverted ${response.reverted} of ${response.total}. ${failureCount} could not be undone.`
);
setResultsOpen(true);
}
} catch (err) {
if (err instanceof AppError && err.status === 503) {
return;
}
const message =
err instanceof AppError
? err.message
: err instanceof Error
? err.message
: "Failed to revert turn.";
toast.error(message);
} finally {
setIsReverting(false);
setConfirmOpen(false);
}
};
return (
<>
<AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}>
<AlertDialogTrigger asChild>
<Button
size="sm"
variant="ghost"
className="text-muted-foreground hover:text-foreground gap-1.5"
onClick={(e) => {
e.stopPropagation();
setConfirmOpen(true);
}}
>
<RotateCcw className="size-3.5" />
<span>Revert turn</span>
<span className="text-xs tabular-nums opacity-70">
{reversibleCount}/{totalCount}
</span>
</Button>
</AlertDialogTrigger>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle>Revert this turn?</AlertDialogTitle>
<AlertDialogDescription>
This will undo {reversibleCount} of {totalCount} action
{totalCount === 1 ? "" : "s"} from this turn in reverse order. The chat history and
any read-only actions are preserved. Some rows may not be reversible partial success
is normal.
</AlertDialogDescription>
</AlertDialogHeader>
<AlertDialogFooter>
<AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel>
<AlertDialogAction
onClick={(e) => {
e.preventDefault();
handleRevertTurn();
}}
disabled={isReverting}
>
{isReverting ? "Reverting…" : "Revert turn"}
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
<AlertDialog open={resultsOpen} onOpenChange={setResultsOpen}>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle>Revert results</AlertDialogTitle>
<AlertDialogDescription>
Some actions could not be reverted. Review per-row outcomes below.
</AlertDialogDescription>
</AlertDialogHeader>
<ul className="max-h-72 overflow-y-auto space-y-2 text-sm">
{results.map((r) => (
<RevertResultRow key={r.action_id} result={r} />
))}
</ul>
<AlertDialogFooter>
<AlertDialogAction onClick={() => setResultsOpen(false)}>Close</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
</>
);
}
/**
 * One row of the revert-results dialog: status icon, tool name,
 * normalised status label, and any server-supplied detail message.
 */
function RevertResultRow({ result }: { result: RevertTurnActionResult }) {
	const succeeded = result.status === "reverted" || result.status === "already_reverted";
	const StatusIcon = succeeded ? CheckIcon : XCircleIcon;
	const hasDetail = Boolean(result.message || result.error);

	return (
		<li className="flex items-start gap-2 rounded-md border bg-muted/30 px-3 py-2">
			<StatusIcon
				className={cn("size-4 mt-0.5 shrink-0", succeeded ? "text-emerald-500" : "text-destructive")}
			/>
			<div className="min-w-0 flex-1">
				<p className="font-medium truncate">
					{getToolDisplayName(result.tool_name)}{" "}
					<span className="ml-1 text-xs text-muted-foreground">
						{result.status.replace(/_/g, " ")}
					</span>
				</p>
				{hasDetail && (
					<p className="text-xs text-muted-foreground mt-0.5">{result.error ?? result.message}</p>
				)}
			</div>
		</li>
	);
}

View file

@ -0,0 +1,27 @@
"use client";
import { makeAssistantDataUI } from "@assistant-ui/react";
/**
* Renders a thin horizontal divider between model steps within a single
* assistant turn. The data part is pushed by `addStepSeparator` in
* `streaming-state.ts` whenever a `start-step` SSE event arrives after
* the message already has non-step content.
*
* Today the backend emits one `start-step` / `finish-step` pair per turn,
* so most messages won't contain a separator. The renderer is wired up so
* the planned per-model-step refactor (A2 follow-up) can light up without
* touching the persistence path.
*/
/** Thin horizontal rule drawn between model steps within one assistant turn. */
function StepSeparator() {
	return (
		<div className="mx-auto my-3 w-full max-w-(--thread-max-width) px-2">
			<div className="border-t border-border/60" />
		</div>
	);
}

// Registered for the "step-separator" data part pushed by the streaming layer.
export const StepSeparatorDataUI = makeAssistantDataUI({
	name: "step-separator",
	render: StepSeparator,
});

View file

@ -39,12 +39,10 @@ import {
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
import {
mentionedDocumentsAtom,
sidebarSelectedDocumentsAtom,
} from "@/atoms/chat/mentioned-documents.atom";
import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom";
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
import { documentsSidebarOpenAtom } from "@/atoms/documents/ui.atoms";
import { membersAtom } from "@/atoms/members/members-query.atoms";
import {
globalNewLLMConfigsAtom,
@ -84,6 +82,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import {
CONNECTOR_ICON_TO_TYPES,
CONNECTOR_TOOL_ICON_PATHS,
getToolDisplayName,
getToolIcon,
} from "@/contracts/enums/toolIcons";
import type { Document } from "@/contracts/types/document.types";
@ -91,6 +90,7 @@ import { useBatchCommentsPreload } from "@/hooks/use-comments";
import { useCommentsSync } from "@/hooks/use-comments-sync";
import { useMediaQuery } from "@/hooks/use-media-query";
import { useElectronAPI } from "@/hooks/use-platform";
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
import { captureDisplayToPngDataUrl } from "@/lib/chat/display-media-capture";
import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events";
import { cn } from "@/lib/utils";
@ -364,12 +364,14 @@ const ClipboardChip: FC<{ text: string; onDismiss: () => void }> = ({ text, onDi
const Composer: FC = () => {
// Document mention state (atoms persist across component remounts)
const [mentionedDocuments, setMentionedDocuments] = useAtom(mentionedDocumentsAtom);
const setSidebarDocs = useSetAtom(sidebarSelectedDocumentsAtom);
const [showDocumentPopover, setShowDocumentPopover] = useState(false);
const [showPromptPicker, setShowPromptPicker] = useState(false);
const [mentionQuery, setMentionQuery] = useState("");
const [actionQuery, setActionQuery] = useState("");
const editorRef = useRef<InlineMentionEditorRef>(null);
const prevMentionedDocsRef = useRef<
Map<string, Pick<Document, "id" | "title" | "document_type">>
>(new Map());
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
const promptPickerRef = useRef<PromptPickerRef>(null);
const viewportRef = useRef<Element | null>(null);
@ -605,7 +607,6 @@ const Composer: FC = () => {
aui.composer().send();
editorRef.current?.clear();
setMentionedDocuments([]);
setSidebarDocs([]);
// With turnAnchor="top", ViewportSlack adds min-height to the last
// assistant message so that scrolling-to-bottom actually positions the
@ -652,43 +653,71 @@ const Composer: FC = () => {
clipboardInitialText,
aui,
setMentionedDocuments,
setSidebarDocs,
threadViewportStore,
]);
const handleDocumentRemove = useCallback(
(docId: number, docType?: string) => {
setMentionedDocuments((prev) =>
prev.filter((doc) => !(doc.id === docId && doc.document_type === docType))
);
setMentionedDocuments((prev) => {
if (!docType) {
// Defensive fallback: keep UI in sync even when chip type is unavailable.
return prev.filter((doc) => doc.id !== docId);
}
const removedKey = getMentionDocKey({ id: docId, document_type: docType });
return prev.filter((doc) => getMentionDocKey(doc) !== removedKey);
});
},
[setMentionedDocuments]
);
const handleDocumentsMention = useCallback(
(documents: Pick<Document, "id" | "title" | "document_type">[]) => {
const existingKeys = new Set(mentionedDocuments.map((d) => `${d.document_type}:${d.id}`));
const newDocs = documents.filter(
(doc) => !existingKeys.has(`${doc.document_type}:${doc.id}`)
);
const editorMentionedDocs = editorRef.current?.getMentionedDocuments() ?? [];
const editorDocKeys = new Set(editorMentionedDocs.map((doc) => getMentionDocKey(doc)));
for (const doc of newDocs) {
for (const doc of documents) {
const key = getMentionDocKey(doc);
if (editorDocKeys.has(key)) continue;
editorRef.current?.insertDocumentChip(doc);
}
setMentionedDocuments((prev) => {
const existingKeySet = new Set(prev.map((d) => `${d.document_type}:${d.id}`));
const uniqueNewDocs = documents.filter(
(doc) => !existingKeySet.has(`${doc.document_type}:${doc.id}`)
);
const existingKeySet = new Set(prev.map((d) => getMentionDocKey(d)));
const uniqueNewDocs = documents.filter((doc) => !existingKeySet.has(getMentionDocKey(doc)));
return [...prev, ...uniqueNewDocs];
});
setMentionQuery("");
},
[mentionedDocuments, setMentionedDocuments]
[setMentionedDocuments]
);
useEffect(() => {
const editor = editorRef.current;
const nextDocsMap = new Map(mentionedDocuments.map((doc) => [getMentionDocKey(doc), doc]));
const prevDocsMap = prevMentionedDocsRef.current;
if (!editor) {
prevMentionedDocsRef.current = nextDocsMap;
return;
}
const editorKeys = new Set(editor.getMentionedDocuments().map(getMentionDocKey));
for (const [key, doc] of nextDocsMap) {
if (prevDocsMap.has(key) || editorKeys.has(key)) continue;
editor.insertDocumentChip(doc, { removeTriggerText: false });
}
for (const [key, doc] of prevDocsMap) {
if (!nextDocsMap.has(key)) {
editor.removeDocumentChip(doc.id, doc.document_type);
}
}
prevMentionedDocsRef.current = nextDocsMap;
}, [mentionedDocuments]);
return (
<ComposerPrimitive.Root className="aui-composer-root relative flex w-full flex-col gap-2">
<ChatSessionStatus
@ -767,8 +796,6 @@ interface ComposerActionProps {
const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false }) => {
const mentionedDocuments = useAtomValue(mentionedDocumentsAtom);
const sidebarDocs = useAtomValue(sidebarSelectedDocumentsAtom);
const setDocumentsSidebarOpen = useSetAtom(documentsSidebarOpenAtom);
const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom);
const [toolsPopoverOpen, setToolsPopoverOpen] = useState(false);
const isDesktop = useMediaQuery("(min-width: 640px)");
@ -1226,15 +1253,6 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
</AnimatePresence>
</button>
)}
{sidebarDocs.length > 0 && (
<button
type="button"
onClick={() => setDocumentsSidebarOpen(true)}
className="rounded-full border border-border/60 bg-accent/50 px-2.5 py-1 text-xs font-medium text-foreground/80 transition-colors hover:bg-accent"
>
{sidebarDocs.length} {sidebarDocs.length === 1 ? "source" : "sources"} selected
</button>
)}
</div>
{!hasModelConfigured && (
<div className="flex items-center gap-1.5 text-amber-600 dark:text-amber-400 text-xs">
@ -1300,12 +1318,14 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
);
};
/** Convert snake_case tool names to human-readable labels */
/**
* Friendly tool name for display in the chat UI. Delegates to the
* shared map in ``contracts/enums/toolIcons`` so unix-style identifiers
 * (``rm``, ``ls``, ``grep``) and snake_cased function names render as
* plain English (e.g. "Delete file", "List files", "Search in files").
*/
function formatToolName(name: string): string {
return name
.split("_")
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
.join(" ");
return getToolDisplayName(name);
}
interface ToolGroup {

View file

@ -1,30 +1,288 @@
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XCircleIcon } from "lucide-react";
import { useMemo, useState } from "react";
import {
type ToolCallMessagePartComponent,
useAuiState,
} from "@assistant-ui/react";
import { useQueryClient } from "@tanstack/react-query";
import { useAtomValue } from "jotai";
import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react";
import { useEffect, useMemo, useState } from "react";
import { toast } from "sonner";
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
import {
DoomLoopApprovalToolUI,
isDoomLoopInterrupt,
} from "@/components/tool-ui/doom-loop-approval";
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
import { getToolIcon } from "@/contracts/enums/toolIcons";
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
AlertDialogTrigger,
} from "@/components/ui/alert-dialog";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Card } from "@/components/ui/card";
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible";
import { Separator } from "@/components/ui/separator";
import { Spinner } from "@/components/ui/spinner";
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
import {
markActionRevertedInCache,
useAgentActionsQuery,
} from "@/hooks/use-agent-actions-query";
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
import { AppError } from "@/lib/error";
import { isInterruptResult } from "@/lib/hitl";
import { cn } from "@/lib/utils";
function formatToolName(name: string): string {
return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
/**
 * Inline Revert button rendered on a tool card when the matching
 * ``AgentActionLog`` row is reversible and hasn't been reverted yet.
 *
 * Reads from the unified ``useAgentActionsQuery`` cache — the SAME
 * react-query cache the agent-actions sheet consumes. SSE events
 * (``data-action-log`` / ``data-action-log-updated``) and
 * ``POST /threads/{id}/revert/{id}`` responses both flow through the
 * cache via ``setQueryData`` helpers, so the card and the sheet stay
 * in lockstep on every code path: page reload, navigation, live
 * stream, post-stream reversibility flip, and explicit revert clicks.
 *
 * Match key (in priority order):
 * 1. ``a.tool_call_id === toolCallId`` — direct hit in parity_v2 when
 * the model streamed ``tool_call_chunks`` so the card's synthetic
 * id IS the LangChain id.
 * 2. ``a.tool_call_id === langchainToolCallId`` — legacy mode (or
 * parity_v2 with provider-side chunk emission) where the card's
 * synthetic id is ``call_<run_id>`` and the LangChain id is
 * backfilled onto the part by ``tool-output-available``.
 * 3. ``(chat_turn_id, tool_name, position-within-turn)`` — fallback
 * for cards whose synthetic id is ``call_<run_id>`` AND whose
 * ``langchainToolCallId`` never got backfilled (provider emitted
 * the tool_call as a single payload with no chunks AND streaming
 * pre-dated the ``tool-output-available langchainToolCallId``
 * backfill, e.g. older threads). Reads the parent message's
 * ``chatTurnId`` and ``content`` via ``useAuiState`` so we can
 * match position-by-tool-name within the turn against the
 * action_log rows the server returned in ``created_at`` order.
 */
function ToolCardRevertButton({
	toolCallId,
	toolName,
	langchainToolCallId,
}: {
	toolCallId: string;
	toolName: string;
	langchainToolCallId?: string;
}) {
	const session = useAtomValue(chatSessionStateAtom);
	const threadId = session?.threadId ?? null;
	const queryClient = useQueryClient();
	const { findByToolCallId, findByChatTurnAndTool } = useAgentActionsQuery(threadId);

	// Parent message metadata, read via the narrowest possible
	// selectors so this card doesn't re-render on every text-delta of
	// every other part in the same message during streaming.
	//
	// IMPORTANT — ``useAuiState`` re-renders the component whenever the
	// returned slice's identity changes. Returning ``message?.content``
	// (an array) would re-render on every token because the runtime
	// rebuilds the parts array. Returning a PRIMITIVE (the position
	// number) lets ``useAuiState``'s ``Object.is`` check short-circuit
	// when the position hasn't actually moved — which is the common
	// case during text streaming, when only ``text``/``reasoning``
	// parts are mutating and the same-toolName tool-call ordering is
	// stable. (See Vercel React rule ``rerender-defer-reads``.)
	const chatTurnId = useAuiState(({ message }) => {
		const meta = message?.metadata as { custom?: { chatTurnId?: string } } | undefined;
		return meta?.custom?.chatTurnId ?? null;
	});
	const positionInTurn = useAuiState(({ message }) => {
		const content = message?.content;
		if (!Array.isArray(content)) return -1;
		let n = -1;
		for (const part of content) {
			if (
				part &&
				typeof part === "object" &&
				(part as { type?: string }).type === "tool-call" &&
				(part as { toolName?: string }).toolName === toolName
			) {
				n += 1;
				if ((part as { toolCallId?: string }).toolCallId === toolCallId) return n;
			}
		}
		return -1;
	});

	const action = useMemo(() => {
		// Tier 1 + 2: O(1) Map-backed direct id match. Covers
		// ~all parity_v2 streams and any legacy stream that backfilled
		// ``langchainToolCallId`` via ``tool-output-available``.
		const direct =
			findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId);
		if (direct) return direct;
		// Tier 3: position-within-turn fallback. Only kicks in when the
		// card has a synthetic ``call_<run_id>`` id AND no
		// ``langchainToolCallId`` was ever backfilled — i.e. the tool
		// was emitted as a single non-chunked payload AND streaming
		// pre-dated the on_tool_end backfill.
		if (!chatTurnId || positionInTurn < 0) return null;
		const turnSameTool = findByChatTurnAndTool(chatTurnId, toolName);
		return turnSameTool[positionInTurn] ?? null;
	}, [
		findByToolCallId,
		findByChatTurnAndTool,
		toolCallId,
		langchainToolCallId,
		chatTurnId,
		toolName,
		positionInTurn,
	]);

	const [isReverting, setIsReverting] = useState(false);
	const [confirmOpen, setConfirmOpen] = useState(false);

	// Guard clauses AFTER every hook call so the hook order stays
	// identical across renders (Rules of Hooks). ``!= null`` is used
	// deliberately to catch both null and undefined in one check.
	if (!action) return null;
	if (!action.reversible) return null;
	if (action.reverted_by_action_id != null) return null;
	if (action.is_revert_action) return null;
	if (action.error != null) return null;
	if (!threadId) return null;

	const handleRevert = async () => {
		setIsReverting(true);
		try {
			const response = await agentActionsApiService.revert(threadId, action.id);
			markActionRevertedInCache(
				queryClient,
				threadId,
				action.id,
				response.new_action_id ?? null
			);
			toast.success(response.message || "Action reverted.");
		} catch (err) {
			// 503 means revert is gated off on this deployment — hide the
			// button silently rather than nagging the user. Any other error
			// is surfaced as a toast so the operator can investigate.
			if (err instanceof AppError && err.status === 503) {
				return;
			}
			const message =
				err instanceof AppError
					? err.message
					: err instanceof Error
						? err.message
						: "Failed to revert action.";
			toast.error(message);
		} finally {
			setIsReverting(false);
			setConfirmOpen(false);
		}
	};

	return (
		<AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}>
			<AlertDialogTrigger asChild>
				<Button
					size="sm"
					variant="outline"
					className="gap-1.5"
					onClick={(e) => {
						e.stopPropagation();
						setConfirmOpen(true);
					}}
					disabled={isReverting}
				>
					{isReverting ? (
						// Spinner's typed props don't accept ``data-icon`` and
						// it renders an <output>, not an <svg>, so Button's
						// auto-sizing rule doesn't apply. Bare spinner +
						// Button's gap handle layout.
						<Spinner size="xs" />
					) : (
						<RotateCcw data-icon="inline-start" />
					)}
					Revert
				</Button>
			</AlertDialogTrigger>
			<AlertDialogContent>
				<AlertDialogHeader>
					<AlertDialogTitle>Revert this action?</AlertDialogTitle>
					<AlertDialogDescription>
						This will undo{" "}
						<span className="font-medium">{getToolDisplayName(action.tool_name)}</span> and add a
						new entry to the history. Your chat is preserved only the changes the agent made to
						your knowledge base or connected apps will be rolled back where possible.
					</AlertDialogDescription>
				</AlertDialogHeader>
				<AlertDialogFooter>
					<AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel>
					<AlertDialogAction
						onClick={(e) => {
							e.preventDefault();
							handleRevert();
						}}
						disabled={isReverting}
						className="gap-1.5"
					>
						{isReverting && <Spinner size="xs" />}
						Revert
					</AlertDialogAction>
				</AlertDialogFooter>
			</AlertDialogContent>
		</AlertDialog>
	);
}
const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
toolName,
argsText,
result,
status,
}) => {
const [isExpanded, setIsExpanded] = useState(false);
/**
* Compact tool-call card.
*
* shadcn composition note: we intentionally use ``Card`` as a visual
* frame WITHOUT ``CardHeader / CardContent``. The full composition's
* ``p-6`` padding doesn't fit a compact collapsible header that IS the
* trigger; using ``Card`` alone preserves the rounded border, shadow,
* and ``bg-card`` token (semantic colors) without forcing a layout
 * that doesn't fit. All status colors use semantic tokens — no manual
* dark-mode overrides, no raw hex.
*/
const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
const { toolCallId, toolName, argsText, result, status } = props;
// ``langchainToolCallId`` is a SurfSense-specific extension the
// streaming pipeline attaches to the tool-call content part so
// the Revert button can resolve its ``AgentActionLog`` row even
// when only the LC id is known. assistant-ui's
// ``ToolCallMessagePartProps`` doesn't list it, but the runtime
// spreads ``{...part}`` so the prop reaches us at runtime.
const langchainToolCallId = (props as { langchainToolCallId?: string }).langchainToolCallId;
const isCancelled = status?.type === "incomplete" && status.reason === "cancelled";
const isError = status?.type === "incomplete" && status.reason === "error";
const isRunning = status?.type === "running" || status?.type === "requires-action";
/*
Per-card expansion state. Initial value is ``isRunning`` so a
card streaming in mounts already-expanded (no flash of
collapsed → expanded on first paint), while a card loaded from
history (status="complete") mounts collapsed. The useEffect
below keeps this in lockstep with this card's own ``isRunning``
when it transitions: false → true auto-expands (e.g. a tool
that re-runs after edit), true → false auto-collapses once the
tool finishes. Because the dep is per-card ``isRunning`` and
not the chat-level streaming flag, sibling cards on the same
assistant turn each manage their own expansion independently.
Once ``isRunning`` is false the user controls expansion via
``onOpenChange``.
*/
const [isExpanded, setIsExpanded] = useState(isRunning);
useEffect(() => {
setIsExpanded(isRunning);
}, [isRunning]);
const errorData = status?.type === "incomplete" ? status.error : undefined;
const serializedError = useMemo(
() => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null),
@ -50,105 +308,207 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
: serializedError
: null;
const Icon = getToolIcon(toolName);
const displayName = formatToolName(toolName);
const displayName = getToolDisplayName(toolName);
const subtitle = errorReason ?? cancelledReason;
return (
<div
<Card
className={cn(
"my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none",
"my-4 max-w-lg overflow-hidden",
isCancelled && "opacity-60",
isError && "border-destructive/20 bg-destructive/5"
isError && "border-destructive/30"
)}
>
<button
type="button"
onClick={() => setIsExpanded((prev) => !prev)}
className="flex w-full items-center gap-3 px-5 py-4 text-left transition-colors hover:bg-muted/50 focus:outline-none focus-visible:outline-none"
{/*
``group`` lets the chevron (rendered as a sibling of the
main trigger button) read the Collapsible Root's
``data-[state=open]`` for rotation. The Collapsible is
fully controlled via ``isExpanded`` the useEffect
above syncs it to ``isRunning`` so the card auto-opens
while a tool streams in and auto-collapses once it
finishes. We deliberately DON'T pass ``disabled`` so
both triggers stay clickable; ``onOpenChange`` is wired
to a setter that no-ops while ``isRunning`` (see
``handleOpenChange`` below) which keeps the card pinned
open mid-stream without losing keyboard / pointer
affordance the moment streaming ends.
*/}
<Collapsible
className="group"
open={isExpanded}
onOpenChange={(next) => {
// Block manual collapse while the tool is still
// streaming — otherwise a stray click on either
// trigger would close the card and hide the live
// ``argsText`` panel mid-run. After streaming the
// user has full control again.
if (isRunning) return;
setIsExpanded(next);
}}
>
<div
className={cn(
"flex size-8 shrink-0 items-center justify-center rounded-lg",
isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10"
)}
>
{isError ? (
<XCircleIcon className="size-4 text-destructive" />
) : isCancelled ? (
<XCircleIcon className="size-4 text-muted-foreground" />
) : isRunning ? (
<Icon className="size-4 text-primary animate-pulse" />
) : (
<CheckIcon className="size-4 text-primary" />
)}
</div>
{/*
Header row: main trigger on the left (icon + title
col), Revert + chevron-trigger on the right as
siblings of the main trigger. The chevron is wrapped
in its OWN ``CollapsibleTrigger`` (Radix supports
multiple triggers per Root) so clicking the chevron
toggles the same state as clicking the title row.
The Revert button stays a separate AlertDialog
trigger and stops propagation in its onClick so it
doesn't toggle the collapsible while opening the
confirm dialog. Keeping these as flat siblings
rather than nesting Revert / chevron inside the
title trigger avoids invalid HTML
(button-in-button) and lets the Revert button
render in BOTH the collapsed and expanded states.
*/}
<div className="flex items-stretch transition-colors hover:bg-muted/50">
<CollapsibleTrigger asChild>
<button
type="button"
className={cn(
"flex flex-1 min-w-0 items-center gap-3 py-4 pl-5 pr-2 text-left",
// Inset ring — Card's ``overflow-hidden`` would
// clip an ``offset-2`` ring; ``ring-inset``
// paints inside the button box.
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
"disabled:cursor-default"
)}
>
<div
className={cn(
"flex size-8 shrink-0 items-center justify-center rounded-lg",
isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10"
)}
>
{isError ? (
<XCircleIcon className="size-4 text-destructive" />
) : isCancelled ? (
<XCircleIcon className="size-4 text-muted-foreground" />
) : isRunning ? (
<Spinner size="sm" className="text-primary" />
) : (
<CheckIcon className="size-4 text-primary" />
)}
</div>
<div className="flex-1 min-w-0">
<p
className={cn(
"text-sm font-semibold",
isError
? "text-destructive"
: isCancelled
? "text-muted-foreground line-through"
: "text-foreground"
)}
>
{isRunning
? displayName
: isCancelled
? `Cancelled: ${displayName}`
: isError
? `Failed: ${displayName}`
: displayName}
</p>
{isRunning && <p className="text-xs text-muted-foreground mt-0.5">Running...</p>}
{cancelledReason && (
<p className="text-xs text-muted-foreground mt-0.5 truncate">{cancelledReason}</p>
)}
{errorReason && (
<p className="text-xs text-destructive/80 mt-0.5 truncate">{errorReason}</p>
)}
</div>
<div className="flex flex-1 min-w-0 flex-col gap-0.5">
<div className="flex items-center gap-2">
<p
className={cn(
"text-sm font-semibold truncate",
isCancelled && "text-muted-foreground line-through",
isError && "text-destructive"
)}
>
{displayName}
</p>
{isRunning && <Badge variant="secondary">Running</Badge>}
{isError && <Badge variant="destructive">Failed</Badge>}
{isCancelled && <Badge variant="outline">Cancelled</Badge>}
</div>
{subtitle && (
<p
className={cn(
"text-xs truncate",
isError ? "text-destructive/80" : "text-muted-foreground"
)}
>
{subtitle}
</p>
)}
</div>
</button>
</CollapsibleTrigger>
{!isRunning && (
<div className="shrink-0 text-muted-foreground">
{isExpanded ? (
<ChevronDownIcon className="size-4" />
) : (
<ChevronUpIcon className="size-4" />
)}
{/*
Right-side controls. The Revert button is
visible whenever the matching action is
reversible — including the collapsed state —
but ``ToolCardRevertButton`` itself returns
``null`` while a tool is still running because
no action-log row exists yet, so it doesn't
need an explicit ``isRunning`` gate here.
*/}
<div className="flex shrink-0 items-center gap-2 pl-2 pr-5">
<ToolCardRevertButton
toolCallId={toolCallId}
toolName={toolName}
langchainToolCallId={langchainToolCallId}
/>
<CollapsibleTrigger asChild>
<button
type="button"
aria-label={isExpanded ? "Collapse details" : "Expand details"}
className={cn(
"flex size-7 shrink-0 items-center justify-center rounded-md",
"text-muted-foreground hover:bg-muted hover:text-foreground",
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
"disabled:cursor-default"
)}
>
<ChevronDownIcon
className={cn(
"size-4 transition-transform duration-200",
"group-data-[state=open]:rotate-180"
)}
/>
</button>
</CollapsibleTrigger>
</div>
)}
</button>
</div>
{isExpanded && !isRunning && (
<>
<div className="mx-5 h-px bg-border/50" />
<div className="px-5 py-3 space-y-3">
{argsText && (
<div>
<p className="text-xs font-medium text-muted-foreground mb-1">Arguments</p>
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all">
{argsText}
</pre>
{/*
CollapsibleContent body — auto-open while streaming
(see ``open`` prop above) so the live ``argsText``
streams into the Inputs panel directly, no need for
a separate "Live input" panel. Native
``overflow-auto`` instead of ``ScrollArea`` because
Radix's Viewport can let content bleed past
``max-h-*`` in dynamic flex layouts. ``min-w-0`` on
the column wrappers guarantees ``break-all`` wraps
correctly within the bounded ``max-w-lg`` Card.
*/}
<CollapsibleContent>
<Separator />
<div className="flex flex-col gap-3 px-5 py-3">
{(argsText || isRunning) && (
<div className="flex flex-col gap-1 min-w-0">
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
<div className="max-h-48 overflow-auto rounded-md bg-muted/40">
{argsText ? (
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
{argsText}
</pre>
) : (
// Bridges the brief gap between
// ``tool-input-start`` (creates the
// card, ``argsText`` undefined) and
// the first ``tool-input-delta``.
<p className="px-3 py-2 text-xs italic text-muted-foreground">
Waiting for input
</p>
)}
</div>
</div>
)}
{!isCancelled && result !== undefined && (
<>
<div className="h-px bg-border/30" />
<div>
<p className="text-xs font-medium text-muted-foreground mb-1">Result</p>
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all">
{typeof result === "string" ? result : serializedResult}
</pre>
<Separator />
<div className="flex flex-col gap-1 min-w-0">
<p className="text-xs font-medium text-muted-foreground">Result</p>
<div className="max-h-64 overflow-auto rounded-md bg-muted/40">
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
{typeof result === "string" ? result : serializedResult}
</pre>
</div>
</div>
</>
)}
</div>
</>
)}
</div>
</CollapsibleContent>
</Collapsible>
</Card>
);
};

View file

@ -1,11 +1,12 @@
import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react";
import { useAtomValue } from "jotai";
import { CheckIcon, CopyIcon, FileText, Pencil } from "lucide-react";
import { CheckIcon, CopyIcon, Pencil } from "lucide-react";
import Image from "next/image";
import { type FC, useState } from "react";
import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
interface AuthorMetadata {
displayName: string | null;
@ -48,6 +49,19 @@ const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => {
export const UserMessage: FC = () => {
const messageId = useAuiState(({ message }) => message?.id);
const messageText = useAuiState(({ message }) =>
(message?.content ?? [])
.map((part) =>
typeof part === "object" &&
part !== null &&
"type" in part &&
(part as { type?: string }).type === "text" &&
"text" in part
? String((part as { text?: string }).text ?? "")
: ""
)
.join("")
);
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
const metadata = useAuiState(({ message }) => message?.metadata);
@ -63,22 +77,12 @@ export const UserMessage: FC = () => {
<div className="col-start-2 min-w-0">
<div className="aui-user-message-content-wrapper flex items-end gap-2">
<div className="relative flex-1 min-w-0">
{mentionedDocs && mentionedDocs.length > 0 && (
<div className="flex flex-wrap items-end gap-2 mb-2 justify-end">
{mentionedDocs?.map((doc) => (
<span
key={`${doc.document_type}:${doc.id}`}
className="inline-flex items-center gap-1 px-2 py-0.5 rounded-full bg-primary/10 text-xs font-medium text-primary border border-primary/20"
title={doc.title}
>
<FileText className="size-3" />
<span className="max-w-[150px] truncate">{doc.title}</span>
</span>
))}
</div>
)}
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
<MessagePrimitive.Parts />
{mentionedDocs && mentionedDocs.length > 0 ? (
<UserMessageWithMentionChips text={messageText} mentionedDocs={mentionedDocs} />
) : (
<MessagePrimitive.Parts />
)}
</div>
<div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto">
<UserActionBar />
@ -95,6 +99,64 @@ export const UserMessage: FC = () => {
);
};
const UserMessageWithMentionChips: FC<{
	text: string;
	mentionedDocs: { id: number; title: string; document_type: string }[];
}> = ({ text, mentionedDocs }) => {
	type Piece =
		| { type: "text"; value: string; start: number }
		| { type: "mention"; doc: { id: number; title: string; document_type: string }; start: number };

	// Longest token first so a title that is a prefix of another title
	// cannot shadow the longer match at the same position.
	const candidates = mentionedDocs
		.map((doc) => ({ doc, token: `@${doc.title}` }))
		.sort((a, b) => b.token.length - a.token.length);

	const pieces: Piece[] = [];
	let cursor = 0;
	let pending = "";
	let pendingStart = 0;

	// Move any accumulated plain text into the piece list.
	const flushPending = () => {
		if (pending) {
			pieces.push({ type: "text", value: pending, start: pendingStart });
			pending = "";
		}
	};

	while (cursor < text.length) {
		const hit = candidates.find(({ token }) => text.startsWith(token, cursor));
		if (hit) {
			flushPending();
			pieces.push({ type: "mention", doc: hit.doc, start: cursor });
			cursor += hit.token.length;
			pendingStart = cursor;
		} else {
			if (!pending) pendingStart = cursor;
			pending += text[cursor];
			cursor += 1;
		}
	}
	flushPending();

	return (
		<span className="whitespace-pre-wrap break-words">
			{pieces.map((piece) =>
				piece.type === "text" ? (
					<span key={`txt-${piece.start}`}>{piece.value}</span>
				) : (
					<span
						key={`mention-${piece.doc.document_type}:${piece.doc.id}-${piece.start}`}
						className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-baseline"
						title={piece.doc.title}
					>
						<span className="flex items-center text-muted-foreground">
							{getConnectorIcon(piece.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
						</span>
						<span className="max-w-[120px] truncate">{piece.doc.title}</span>
					</span>
				)
			)}
		</span>
	);
};
const UserActionBar: FC = () => {
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);

View file

@ -7,6 +7,7 @@ import { DndProvider } from "react-dnd";
import { HTML5Backend } from "react-dnd-html5-backend";
import { renamingFolderIdAtom } from "@/atoms/documents/folder.atoms";
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
import { DocumentNode, type DocumentNodeDoc } from "./DocumentNode";
import { type FolderDisplay, FolderNode } from "./FolderNode";
@ -17,7 +18,7 @@ interface FolderTreeViewProps {
documents: DocumentNodeDoc[];
expandedIds: Set<number>;
onToggleExpand: (folderId: number) => void;
mentionedDocIds: Set<number>;
mentionedDocKeys: Set<string>;
onToggleChatMention: (
doc: { id: number; title: string; document_type: string },
isMentioned: boolean
@ -62,7 +63,7 @@ export function FolderTreeView({
documents,
expandedIds,
onToggleExpand,
mentionedDocIds,
mentionedDocKeys,
onToggleChatMention,
onToggleFolderSelect,
onRenameFolder,
@ -181,7 +182,7 @@ export function FolderTreeView({
function compute(folderId: number): { selected: number; total: number } {
const directDocs = (docsByFolder[folderId] ?? []).filter(isSelectable);
let selected = directDocs.filter((d) => mentionedDocIds.has(d.id)).length;
let selected = directDocs.filter((d) => mentionedDocKeys.has(getMentionDocKey(d))).length;
let total = directDocs.length;
for (const child of foldersByParent[folderId] ?? []) {
@ -202,7 +203,7 @@ export function FolderTreeView({
if (states[f.id] === undefined) compute(f.id);
}
return states;
}, [folders, docsByFolder, foldersByParent, mentionedDocIds]);
}, [folders, docsByFolder, foldersByParent, mentionedDocKeys]);
const folderMap = useMemo(() => {
const map: Record<number, FolderDisplay> = {};
@ -276,7 +277,7 @@ export function FolderTreeView({
key={`doc-${d.id}`}
doc={d}
depth={depth}
isMentioned={mentionedDocIds.has(d.id)}
isMentioned={mentionedDocKeys.has(getMentionDocKey(d))}
onToggleChatMention={onToggleChatMention}
onPreview={onPreviewDocument}
onEdit={onEditDocument}
@ -356,7 +357,7 @@ export function FolderTreeView({
key={`doc-${d.id}`}
doc={d}
depth={depth}
isMentioned={mentionedDocIds.has(d.id)}
isMentioned={mentionedDocKeys.has(getMentionDocKey(d))}
onToggleChatMention={onToggleChatMention}
onPreview={onPreviewDocument}
onEdit={onEditDocument}

View file

@ -9,6 +9,7 @@ import {
import { Turnstile, type TurnstileInstance } from "@marsidev/react-turnstile";
import { ShieldCheck } from "lucide-react";
import { useCallback, useEffect, useRef, useState } from "react";
import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator";
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
import {
createTokenUsageStore,
@ -17,10 +18,14 @@ import {
} from "@/components/assistant-ui/token-usage-context";
import { useAnonymousMode } from "@/contexts/anonymous-mode";
import {
addStepSeparator,
addToolCall,
appendReasoning,
appendText,
appendToolInputDelta,
buildContentForUI,
type ContentPartsState,
endReasoning,
FrameBatchedUpdater,
readSSEStream,
type ThinkingStepData,
@ -32,7 +37,9 @@ import { trackAnonymousChatMessageSent } from "@/lib/posthog/events";
import { FreeModelSelector } from "./free-model-selector";
import { FreeThread } from "./free-thread";
const TOOLS_WITH_UI = new Set(["web_search", "document_qna"]);
// Render all tool calls via ToolFallback; backend keeps persisted
// payloads bounded by summarising / truncating outputs.
const TOOLS_WITH_UI = "all" as const;
const TURNSTILE_SITE_KEY = process.env.NEXT_PUBLIC_TURNSTILE_SITE_KEY ?? "";
/** Try to parse a CAPTCHA_REQUIRED or CAPTCHA_INVALID code from a non-ok response. */
@ -125,6 +132,7 @@ export function FreeChatPage() {
const contentPartsState: ContentPartsState = {
contentParts: [],
currentTextPartIndex: -1,
currentReasoningPartIndex: -1,
toolCallIndices: new Map(),
};
const { toolCallIndices } = contentPartsState;
@ -139,6 +147,10 @@ export function FreeChatPage() {
);
};
const scheduleFlush = () => batcher.schedule(flushMessages);
const forceFlush = () => {
scheduleFlush();
batcher.flush();
};
try {
for await (const parsed of readSSEStream(response)) {
@ -148,29 +160,74 @@ export function FreeChatPage() {
scheduleFlush();
break;
case "tool-input-start":
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
batcher.flush();
case "reasoning-delta":
appendReasoning(contentPartsState, parsed.delta);
scheduleFlush();
break;
case "tool-input-available":
case "reasoning-end":
endReasoning(contentPartsState);
scheduleFlush();
break;
case "start-step":
addStepSeparator(contentPartsState);
scheduleFlush();
break;
case "finish-step":
break;
case "tool-input-start":
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
parsed.toolCallId,
parsed.toolName,
{},
false,
parsed.langchainToolCallId
);
forceFlush();
break;
case "tool-input-delta":
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
scheduleFlush();
break;
case "tool-input-available": {
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
if (toolCallIndices.has(parsed.toolCallId)) {
updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} });
updateToolCall(contentPartsState, parsed.toolCallId, {
args: parsed.input || {},
argsText: finalArgsText,
langchainToolCallId: parsed.langchainToolCallId,
});
} else {
addToolCall(
contentPartsState,
TOOLS_WITH_UI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}
parsed.input || {},
false,
parsed.langchainToolCallId
);
updateToolCall(contentPartsState, parsed.toolCallId, {
argsText: finalArgsText,
});
}
batcher.flush();
forceFlush();
break;
}
case "tool-output-available":
updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output });
batcher.flush();
updateToolCall(contentPartsState, parsed.toolCallId, {
result: parsed.output,
langchainToolCallId: parsed.langchainToolCallId,
});
forceFlush();
break;
case "data-thinking-step": {
@ -369,6 +426,7 @@ export function FreeChatPage() {
<TokenUsageProvider store={tokenUsageStore}>
<AssistantRuntimeProvider runtime={runtime}>
<ThinkingStepsDataUI />
<StepSeparatorDataUI />
<div className="flex h-full flex-col overflow-hidden">
<div className="flex h-14 shrink-0 items-center justify-between border-b border-border/40 px-4">
<FreeModelSelector />

View file

@ -23,7 +23,9 @@ import { useTranslations } from "next-intl";
import type React from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { sidebarSelectedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom";
import {
mentionedDocumentsAtom,
} from "@/atoms/chat/mentioned-documents.atom";
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
import { deleteDocumentMutationAtom } from "@/atoms/documents/document-mutation.atoms";
@ -72,6 +74,7 @@ import type { DocumentTypeEnum } from "@/contracts/types/document.types";
import { useDebouncedValue } from "@/hooks/use-debounced-value";
import { useMediaQuery } from "@/hooks/use-media-query";
import { useElectronAPI, usePlatform } from "@/hooks/use-platform";
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { foldersApiService } from "@/lib/apis/folders-api.service";
@ -425,8 +428,11 @@ function AuthenticatedDocumentsSidebarBase({
}, [refreshWatchedIds]);
const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom);
const [sidebarDocs, setSidebarDocs] = useAtom(sidebarSelectedDocumentsAtom);
const mentionedDocIds = useMemo(() => new Set(sidebarDocs.map((d) => d.id)), [sidebarDocs]);
const [sidebarDocs, setSidebarDocs] = useAtom(mentionedDocumentsAtom);
const mentionedDocKeys = useMemo(
() => new Set(sidebarDocs.map((d) => getMentionDocKey(d))),
[sidebarDocs]
);
// Folder state
const [expandedFolderMap, setExpandedFolderMap] = useAtom(expandedFolderIdsAtom);
@ -874,11 +880,12 @@ function AuthenticatedDocumentsSidebarBase({
const handleToggleChatMention = useCallback(
(doc: { id: number; title: string; document_type: string }, isMentioned: boolean) => {
const key = getMentionDocKey(doc);
if (isMentioned) {
setSidebarDocs((prev) => prev.filter((d) => d.id !== doc.id));
setSidebarDocs((prev) => prev.filter((d) => getMentionDocKey(d) !== key));
} else {
setSidebarDocs((prev) => {
if (prev.some((d) => d.id === doc.id)) return prev;
if (prev.some((d) => getMentionDocKey(d) === key)) return prev;
return [
...prev,
{ id: doc.id, title: doc.title, document_type: doc.document_type as DocumentTypeEnum },
@ -909,9 +916,9 @@ function AuthenticatedDocumentsSidebarBase({
if (selectAll) {
setSidebarDocs((prev) => {
const existingIds = new Set(prev.map((d) => d.id));
const existingDocKeys = new Set(prev.map((d) => getMentionDocKey(d)));
const newDocs = subtreeDocs
.filter((d) => !existingIds.has(d.id))
.filter((d) => !existingDocKeys.has(getMentionDocKey(d)))
.map((d) => ({
id: d.id,
title: d.title,
@ -920,8 +927,8 @@ function AuthenticatedDocumentsSidebarBase({
return newDocs.length > 0 ? [...prev, ...newDocs] : prev;
});
} else {
const idsToRemove = new Set(subtreeDocs.map((d) => d.id));
setSidebarDocs((prev) => prev.filter((d) => !idsToRemove.has(d.id)));
const keysToRemove = new Set(subtreeDocs.map((d) => getMentionDocKey(d)));
setSidebarDocs((prev) => prev.filter((d) => !keysToRemove.has(getMentionDocKey(d))));
}
},
[treeDocuments, foldersByParent, setSidebarDocs]
@ -1157,7 +1164,7 @@ function AuthenticatedDocumentsSidebarBase({
documents={searchFilteredDocuments}
expandedIds={expandedIds}
onToggleExpand={toggleFolderExpand}
mentionedDocIds={mentionedDocIds}
mentionedDocKeys={mentionedDocKeys}
onToggleChatMention={handleToggleChatMention}
onToggleFolderSelect={handleToggleFolderSelect}
onRenameFolder={handleRenameFolder}
@ -1585,16 +1592,20 @@ function AnonymousDocumentsSidebar({
const [isUploading, setIsUploading] = useState(false);
const [search, setSearch] = useState("");
const [sidebarDocs, setSidebarDocs] = useAtom(sidebarSelectedDocumentsAtom);
const mentionedDocIds = useMemo(() => new Set(sidebarDocs.map((d) => d.id)), [sidebarDocs]);
const [sidebarDocs, setSidebarDocs] = useAtom(mentionedDocumentsAtom);
const mentionedDocKeys = useMemo(
() => new Set(sidebarDocs.map((d) => getMentionDocKey(d))),
[sidebarDocs]
);
const handleToggleChatMention = useCallback(
(doc: { id: number; title: string; document_type: string }, isMentioned: boolean) => {
const key = getMentionDocKey(doc);
if (isMentioned) {
setSidebarDocs((prev) => prev.filter((d) => d.id !== doc.id));
setSidebarDocs((prev) => prev.filter((d) => getMentionDocKey(d) !== key));
} else {
setSidebarDocs((prev) => {
if (prev.some((d) => d.id === doc.id)) return prev;
if (prev.some((d) => getMentionDocKey(d) === key)) return prev;
return [
...prev,
{ id: doc.id, title: doc.title, document_type: doc.document_type as DocumentTypeEnum },
@ -1814,7 +1825,7 @@ function AnonymousDocumentsSidebar({
documents={searchFilteredDocs}
expandedIds={new Set()}
onToggleExpand={() => {}}
mentionedDocIds={mentionedDocIds}
mentionedDocKeys={mentionedDocKeys}
onToggleChatMention={handleToggleChatMention}
onToggleFolderSelect={() => {}}
onRenameFolder={() => gate("rename folders")}

View file

@ -1,6 +1,7 @@
"use client";
import { AssistantRuntimeProvider } from "@assistant-ui/react";
import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator";
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
import { Navbar } from "@/components/homepage/navbar";
import { ReportPanel } from "@/components/report-panel/report-panel";
@ -41,6 +42,7 @@ export function PublicChatView({ shareToken }: PublicChatViewProps) {
<Navbar scrolledBgClassName={navbarScrolledBg} />
<AssistantRuntimeProvider runtime={runtime}>
<ThinkingStepsDataUI />
<StepSeparatorDataUI />
<div className="flex h-screen pt-16 overflow-hidden">
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
<PublicThread footer={<PublicChatFooter shareToken={shareToken} />} />

View file

@ -13,6 +13,7 @@ import Image from "next/image";
import { type FC, type ReactNode, useState } from "react";
import { CitationMetadataProvider } from "@/components/assistant-ui/citation-metadata-context";
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part";
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import { GenerateImageToolUI } from "@/components/tool-ui/generate-image";
@ -157,6 +158,7 @@ const PublicAssistantMessage: FC = () => {
<MessagePrimitive.Parts
components={{
Text: MarkdownText,
Reasoning: ReasoningMessagePart,
tools: {
by_name: {
generate_podcast: GeneratePodcastToolUI,

View file

@ -8,6 +8,7 @@ import { TextShimmerLoader } from "@/components/prompt-kit/loader";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Textarea } from "@/components/ui/textarea";
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
import { useHitlPhase } from "@/hooks/use-hitl-phase";
import { connectorsApiService } from "@/lib/apis/connectors-api.service";
import type { HitlDecision, InterruptResult } from "@/lib/hitl";
@ -77,7 +78,7 @@ function GenericApprovalCard({
const [editedParams, setEditedParams] = useState<Record<string, unknown>>(args);
const [isEditing, setIsEditing] = useState(false);
const displayName = toolName.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
const displayName = getToolDisplayName(toolName);
const mcpServer = interruptData.context?.mcp_server as string | undefined;
const toolDescription = interruptData.context?.tool_description as string | undefined;
@ -186,12 +187,11 @@ function GenericApprovalCard({
</>
)}
{/* Parameters */}
{Object.keys(args).length > 0 && (
<>
<div className="mx-5 h-px bg-border/50" />
<div className="px-5 py-4 space-y-2">
<p className="text-xs font-medium text-muted-foreground">Parameters</p>
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
{phase === "pending" && isEditing ? (
<ParamEditor
params={editedParams}

View file

@ -1,33 +1,223 @@
import {
BookOpen,
Brain,
Calendar,
Check,
FileEdit,
FilePlus,
FileText,
FileUser,
FileX,
Film,
FolderPlus,
FolderTree,
FolderX,
Globe,
ImageIcon,
ListTodo,
type LucideIcon,
Mail,
MessagesSquare,
Move,
Plus,
Podcast,
ScanLine,
Search,
Send,
Trash2,
Wrench,
} from "lucide-react";
/**
 * Every tool now renders a card via ``ToolFallback``. The icon map is
 * keyed on the canonical backend tool name (registered in
 * ``surfsense_backend/app/agents/new_chat/tools/registry.py``); unknown
 * names fall back to the generic ``Wrench`` icon so the card still
 * communicates "this is a tool call".
 *
 * NOTE(review): several distinct tools intentionally share an icon
 * (e.g. all calendar mutations use ``Calendar``); the display-name map
 * below is what disambiguates them for the user.
 */
const TOOL_ICONS: Record<string, LucideIcon> = {
  // Generators
  generate_podcast: Podcast,
  generate_video_presentation: Film,
  generate_report: FileText,
  generate_resume: FileUser,
  generate_image: ImageIcon,
  display_image: ImageIcon,
  // Web / search
  scrape_webpage: ScanLine,
  web_search: Globe,
  search_surfsense_docs: BookOpen,
  // Memory
  update_memory: Brain,
  // Filesystem (built-in deepagent + middleware)
  read_file: FileText,
  write_file: FilePlus,
  edit_file: FileEdit,
  move_file: Move,
  rm: FileX,
  rmdir: FolderX,
  mkdir: FolderPlus,
  ls: FolderTree,
  write_todos: ListTodo,
  // Calendar
  search_calendar_events: Search,
  create_calendar_event: Calendar,
  update_calendar_event: Calendar,
  delete_calendar_event: Calendar,
  // Gmail
  search_gmail: Search,
  read_gmail_email: Mail,
  create_gmail_draft: Mail,
  update_gmail_draft: FileEdit,
  send_gmail_email: Send,
  trash_gmail_email: Trash2,
  // Notion / Confluence pages
  create_notion_page: FilePlus,
  update_notion_page: FileEdit,
  delete_notion_page: FileX,
  create_confluence_page: FilePlus,
  update_confluence_page: FileEdit,
  delete_confluence_page: FileX,
  // Linear / Jira issues
  create_linear_issue: Plus,
  update_linear_issue: FileEdit,
  delete_linear_issue: Trash2,
  create_jira_issue: Plus,
  update_jira_issue: FileEdit,
  delete_jira_issue: Trash2,
  // Drive-like file connectors
  create_google_drive_file: FilePlus,
  delete_google_drive_file: FileX,
  create_dropbox_file: FilePlus,
  delete_dropbox_file: FileX,
  create_onedrive_file: FilePlus,
  delete_onedrive_file: FileX,
  // Chat connectors
  list_discord_channels: MessagesSquare,
  read_discord_messages: MessagesSquare,
  send_discord_message: Send,
  list_teams_channels: MessagesSquare,
  read_teams_messages: MessagesSquare,
  send_teams_message: Send,
  // Luma
  list_luma_events: Calendar,
  read_luma_event: Calendar,
  create_luma_event: Calendar,
  // Misc
  get_connected_accounts: Check,
  execute: Wrench,
  execute_code: Wrench,
};

/**
 * Resolve the lucide icon for a canonical tool name.
 *
 * Unknown names (e.g. dynamically registered MCP tools) fall back to
 * the generic ``Wrench`` icon.
 */
export function getToolIcon(name: string): LucideIcon {
  return TOOL_ICONS[name] ?? Wrench;
}
/**
 * Friendly display names for tools shown in the chat UI.
 *
 * Most users aren't engineers; they shouldn't see raw unix-style
 * identifiers like ``rm`` / ``rmdir`` / ``ls`` / ``grep`` / ``glob`` or
 * snake_cased function names. The map below renders each tool with
 * plain English wording (verb + object) so non-technical users
 * understand what the agent is doing at a glance.
 *
 * Unmapped tool names fall back to a snake_case-to-Title-Case
 * conversion via {@link getToolDisplayName}.
 */
const TOOL_DISPLAY_NAMES: Record<string, string> = {
  // Filesystem / knowledge base
  read_file: "Read file",
  write_file: "Write file",
  edit_file: "Edit file",
  move_file: "Move file",
  rm: "Delete file",
  rmdir: "Delete folder",
  mkdir: "Create folder",
  ls: "List files",
  glob: "Find files",
  grep: "Search in files",
  write_todos: "Plan tasks",
  save_document: "Save document",
  // Generators
  generate_podcast: "Generate podcast",
  generate_video_presentation: "Generate video presentation",
  generate_report: "Generate report",
  generate_resume: "Generate resume",
  generate_image: "Generate image",
  display_image: "Show image",
  // Web / search
  scrape_webpage: "Read webpage",
  web_search: "Search the web",
  search_surfsense_docs: "Search knowledge base",
  // Memory
  update_memory: "Update memory",
  // Calendar
  search_calendar_events: "Search calendar",
  create_calendar_event: "Create event",
  update_calendar_event: "Update event",
  delete_calendar_event: "Delete event",
  // Gmail
  search_gmail: "Search Gmail",
  read_gmail_email: "Read email",
  create_gmail_draft: "Draft email",
  update_gmail_draft: "Update draft",
  send_gmail_email: "Send email",
  trash_gmail_email: "Move email to trash",
  // Notion
  create_notion_page: "Create Notion page",
  update_notion_page: "Update Notion page",
  delete_notion_page: "Delete Notion page",
  // Confluence
  create_confluence_page: "Create Confluence page",
  update_confluence_page: "Update Confluence page",
  delete_confluence_page: "Delete Confluence page",
  // Linear
  create_linear_issue: "Create Linear issue",
  update_linear_issue: "Update Linear issue",
  delete_linear_issue: "Delete Linear issue",
  // Jira
  create_jira_issue: "Create Jira issue",
  update_jira_issue: "Update Jira issue",
  delete_jira_issue: "Delete Jira issue",
  // Drive-like file connectors
  create_google_drive_file: "Create Google Drive file",
  delete_google_drive_file: "Delete Google Drive file",
  create_dropbox_file: "Create Dropbox file",
  delete_dropbox_file: "Delete Dropbox file",
  create_onedrive_file: "Create OneDrive file",
  delete_onedrive_file: "Delete OneDrive file",
  // Discord
  list_discord_channels: "List Discord channels",
  read_discord_messages: "Read Discord messages",
  send_discord_message: "Send Discord message",
  // Teams
  list_teams_channels: "List Teams channels",
  read_teams_messages: "Read Teams messages",
  send_teams_message: "Send Teams message",
  // Luma
  list_luma_events: "List Luma events",
  read_luma_event: "Read Luma event",
  create_luma_event: "Create Luma event",
  // Misc
  get_connected_accounts: "Check connected accounts",
  execute: "Run command",
  execute_code: "Run code",
};

/**
 * Format a tool's canonical (snake_case) name for display in the chat UI.
 *
 * Looks up {@link TOOL_DISPLAY_NAMES} first; falls back to a
 * snake_case-to-Title-Case rewrite for tools that don't have a curated
 * label (e.g. dynamically registered MCP tools).
 */
export function getToolDisplayName(name: string): string {
  const friendly = TOOL_DISPLAY_NAMES[name];
  if (friendly) return friendly;
  // "my_mcp_tool" -> "my mcp tool" -> "My Mcp Tool"
  return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
}
export const CONNECTOR_TOOL_ICON_PATHS: Record<string, { src: string; alt: string }> = {
gmail: { src: "/connectors/google-gmail.svg", alt: "Gmail" },
google_calendar: { src: "/connectors/google-calendar.svg", alt: "Google Calendar" },

View file

@ -1,7 +1,13 @@
import { z } from "zod";
/**
* Raw message from database (real-time sync)
* Raw message from database (real-time sync).
*
* ``turn_id`` is included so consumers (e.g. ``convertToThreadMessage``)
* can populate ``metadata.custom.chatTurnId`` on the
* ``ThreadMessageLike`` even after the live-collab Zero re-sync. The
* inline Revert button's ``(chat_turn_id, tool_name, position)``
* fallback in tool-fallback.tsx depends on it.
*/
export const rawMessage = z.object({
id: z.number(),
@ -10,6 +16,7 @@ export const rawMessage = z.object({
content: z.unknown(),
author_id: z.string().nullable(),
created_at: z.string(),
turn_id: z.string().nullable().optional(),
});
export type RawMessage = z.infer<typeof rawMessage>;

View file

@ -0,0 +1,416 @@
"use client";
import { type QueryClient, useQuery } from "@tanstack/react-query";
import { useCallback, useEffect, useMemo, useRef } from "react";
import {
type AgentAction,
type AgentActionListResponse,
agentActionsApiService,
} from "@/lib/apis/agent-actions-api.service";
// =============================================================================
// DIAGNOSTIC LOGGING. One compile-time switch controls the whole
// SSE → cache → card → button trace. Set ``RevertDebug`` to ``true`` to
// stream the pipeline diagnostics to the browser console; leave it
// ``false`` (the default) so production stays quiet. The plumbing stays
// in place because the underlying id-mismatch failure mode is
// rare-but-real and surfaces only at runtime.
// =============================================================================
const RevertDebug = false;

function dbg(...args: unknown[]): void {
  if (!RevertDebug) return;
  if (typeof window === "undefined") return;
  // eslint-disable-next-line no-console
  console.log("[RevertDebug]", ...args);
}
/**
* Unified store for ``AgentActionLog`` rows scoped to one thread.
*
* Replaces the previous SSE side-channel atom mess
* (``agentActionByLcIdAtom`` / ``agentActionByToolCallIdAtom`` /
* ``agentActionsByChatTurnIdAtom``) and the standalone hydration hook.
* One react-query cache entry is now the single source of truth for:
*
* * the inline Revert button on every tool-call card
* * the per-turn "Revert turn" button under each assistant message
* * the edit-from-position pre-flight that decides whether to show
* the confirmation dialog
* * the agent-actions sheet
*
* The cache is hydrated by ``GET /threads/{id}/actions`` (sized to
* 200, the server max) and updated incrementally by helpers that turn
* SSE events / revert RPC responses into ``setQueryData`` mutations.
* That keeps the card and the sheet in lockstep on every code path —
page reload, navigation, live stream, post-stream reversibility flip,
* and explicit revert clicks.
*/
/**
 * Page size used when hydrating the per-thread action-log cache
 * (sized to the server maximum so a single request fetches the lot).
 */
export const ACTION_LOG_PAGE_SIZE = 200;

/**
 * Stable react-query cache key for a thread's action list.
 *
 * A ``null`` thread id maps to a distinct ``"none"`` key so a dormant
 * query never collides with a real thread's cache entry.
 */
export function agentActionsQueryKey(threadId: number | null) {
  if (threadId === null) {
    return ["agent-actions", "none"] as const;
  }
  return ["agent-actions", threadId] as const;
}
/** Subset of the SSE ``data-action-log`` payload we care about. */
export interface ActionLogSseEvent {
  // AgentActionLog row id — the upsert key used by ``applyActionLogSse``.
  id: number;
  // LangChain tool_call_id; stored as ``tool_call_id`` on the cached row.
  lc_tool_call_id: string | null;
  chat_turn_id: string | null;
  tool_name: string;
  reversible: boolean;
  // True when the server persisted a reverse descriptor; the cache only
  // stores an ``{}`` placeholder until a REST refetch backfills it.
  reverse_descriptor_present: boolean;
  // True when the tool errored; likewise mapped to an ``{}`` placeholder.
  error: boolean;
  // ISO timestamp; ``applyActionLogSse`` substitutes "now" when null.
  created_at: string | null;
}
/**
 * Append or upsert a freshly-emitted ``AgentActionLog`` row into the
 * thread-scoped query cache.
 *
 * The SSE payload is a strict subset of ``AgentAction``; missing
 * fields (``args``, ``reverse_descriptor``, ``user_id``) are filled
 * with ``null`` placeholders. The next refetch (sheet open, user
 * focus, route stale) backfills them but the inline Revert button
 * only reads the fields the SSE payload carries, so it lights up
 * immediately.
 *
 * @param queryClient - react-query client owning the cache entry.
 * @param threadId - thread whose action list is being updated.
 * @param searchSpaceId - copied verbatim onto the placeholder row.
 * @param event - the ``data-action-log`` SSE payload subset.
 */
export function applyActionLogSse(
  queryClient: QueryClient,
  threadId: number,
  searchSpaceId: number,
  event: ActionLogSseEvent
): void {
  dbg("applyActionLogSse: incoming SSE event", {
    threadId,
    searchSpaceId,
    event,
  });
  queryClient.setQueryData<AgentActionListResponse>(
    agentActionsQueryKey(threadId),
    (prev) => {
      // Expand the SSE subset into a full AgentAction; fields the SSE
      // does not carry become null placeholders until the next refetch.
      const placeholder: AgentAction = {
        id: event.id,
        thread_id: threadId,
        user_id: null,
        search_space_id: searchSpaceId,
        tool_name: event.tool_name,
        args: null,
        result_id: null,
        reversible: event.reversible,
        reverse_descriptor: event.reverse_descriptor_present ? {} : null,
        error: event.error ? {} : null,
        reverse_of: null,
        reverted_by_action_id: null,
        is_revert_action: false,
        tool_call_id: event.lc_tool_call_id,
        chat_turn_id: event.chat_turn_id,
        created_at: event.created_at ?? new Date().toISOString(),
      };
      // No cache for this thread yet — seed a single-page list.
      if (!prev) {
        return {
          items: [placeholder],
          total: 1,
          page: 0,
          page_size: ACTION_LOG_PAGE_SIZE,
          has_more: false,
        };
      }
      const existingIdx = prev.items.findIndex((a) => a.id === event.id);
      if (existingIdx >= 0) {
        // Upsert path: a REST refetch may already hold this row. Merge
        // only the fields the SSE refreshes; ``??`` keeps previously
        // backfilled correlation ids when the event carries null.
        const merged = [...prev.items];
        const existing = merged[existingIdx];
        if (existing) {
          merged[existingIdx] = {
            ...existing,
            reversible: event.reversible,
            tool_call_id: event.lc_tool_call_id ?? existing.tool_call_id,
            chat_turn_id: event.chat_turn_id ?? existing.chat_turn_id,
          };
        }
        dbg("applyActionLogSse: merged into existing entry", {
          id: event.id,
          tool_call_id: merged[existingIdx]?.tool_call_id,
          reversible: merged[existingIdx]?.reversible,
        });
        return { ...prev, items: merged };
      }
      dbg("applyActionLogSse: appended new placeholder", {
        id: event.id,
        tool_call_id: placeholder.tool_call_id,
        tool_name: placeholder.tool_name,
        reversible: placeholder.reversible,
        cacheSizeAfter: prev.items.length + 1,
      });
      // REST returns newest-first — keep that ordering when
      // the server eventually refetches by prepending.
      return {
        ...prev,
        items: [placeholder, ...prev.items],
        total: prev.total + 1,
      };
    }
  );
}
/**
 * Propagate a ``data-action-log-updated`` SSE event — the
 * post-SAVEPOINT flip of one row's ``reversible`` flag — into the
 * thread-scoped query cache.
 *
 * Returns the previous cache reference unchanged (a react-query no-op)
 * when the thread has no cache yet or the row id is not present; both
 * cases are traced via {@link dbg} so dropped flips are diagnosable.
 */
export function applyActionLogUpdatedSse(
  queryClient: QueryClient,
  threadId: number,
  id: number,
  reversible: boolean
): void {
  dbg("applyActionLogUpdatedSse: reversibility flip", {
    threadId,
    id,
    reversible,
  });
  queryClient.setQueryData<AgentActionListResponse>(
    agentActionsQueryKey(threadId),
    (prev) => {
      if (!prev) {
        dbg("applyActionLogUpdatedSse: NO prev cache for thread; flip dropped", {
          threadId,
          id,
        });
        return prev;
      }
      let didFlip = false;
      const nextItems = prev.items.map((action) => {
        if (action.id !== id) return action;
        didFlip = true;
        return { ...action, reversible };
      });
      if (!didFlip) {
        dbg("applyActionLogUpdatedSse: id not in cache; flip dropped", {
          threadId,
          id,
          cacheSize: prev.items.length,
          cacheIds: prev.items.map((action) => action.id),
        });
        return prev;
      }
      return { ...prev, items: nextItems };
    }
  );
}
/**
 * Optimistically flag action ``id`` as reverted in the cached list.
 *
 * Called by the inline / per-turn Revert buttons right after the
 * server confirms success, so the UI flips to "Reverted" without
 * waiting for a refetch. ``newActionId`` is the id of the new
 * ``is_revert_action`` row the server inserted; pass ``null`` when the
 * server didn't return it — the ``-1`` sentinel then records "known
 * reverted, new row id unknown".
 */
export function markActionRevertedInCache(
  queryClient: QueryClient,
  threadId: number,
  id: number,
  newActionId: number | null
): void {
  queryClient.setQueryData<AgentActionListResponse>(
    agentActionsQueryKey(threadId),
    (prev) => {
      if (!prev) return prev;
      let didMark = false;
      const nextItems = prev.items.map((action) => {
        if (action.id !== id) return action;
        didMark = true;
        return {
          ...action,
          reverted_by_action_id: newActionId ?? -1,
        };
      });
      // Unchanged reference when the id is absent — avoids a
      // needless subscriber re-render.
      return didMark ? { ...prev, items: nextItems } : prev;
    }
  );
}
/**
 * Fold a per-turn batch revert response into the cached action list.
 *
 * Every entry (the ``reverted`` / ``already_reverted`` buckets of the
 * response) gets its ``reverted_by_action_id`` set; rows not mentioned
 * are left untouched. A ``null`` new-action id is stored as the ``-1``
 * "reverted, id unknown" sentinel.
 */
export function applyRevertTurnResultsToCache(
  queryClient: QueryClient,
  threadId: number,
  entries: Array<{ id: number; newActionId: number | null }>
): void {
  if (entries.length === 0) return;
  // Index once so the per-row check below is O(1).
  const byActionId = new Map(entries.map((entry) => [entry.id, entry.newActionId]));
  queryClient.setQueryData<AgentActionListResponse>(
    agentActionsQueryKey(threadId),
    (prev) => {
      if (!prev) return prev;
      let touched = false;
      const nextItems = prev.items.map((action) => {
        if (!byActionId.has(action.id)) return action;
        touched = true;
        const revertRowId = byActionId.get(action.id) ?? null;
        return { ...action, reverted_by_action_id: revertRowId ?? -1 };
      });
      return touched ? { ...prev, items: nextItems } : prev;
    }
  );
}
/**
* Read-side hook used by the card, the turn button, the sheet, and
* the edit-from-position pre-flight.
*
* Returns the raw query state plus convenience selectors so consumers
* don't reach into ``data.items`` directly. ``enabled`` is the only
* knob pass ``false`` to keep the query dormant when the consumer
* doesn't yet have a thread id.
*/
export function useAgentActionsQuery(
  threadId: number | null,
  options: { enabled?: boolean } = {}
) {
  // Query is dormant until the caller both opts in and has a thread id.
  const enabled = (options.enabled ?? true) && threadId !== null;
  const query = useQuery({
    queryKey: agentActionsQueryKey(threadId),
    queryFn: async () => {
      dbg("useAgentActionsQuery: REST fetch START", {
        threadId,
        pageSize: ACTION_LOG_PAGE_SIZE,
      });
      // ``threadId as number`` is safe: ``enabled`` above guarantees
      // the queryFn never runs with a null thread id.
      const res = await agentActionsApiService.listForThread(threadId as number, {
        page: 0,
        pageSize: ACTION_LOG_PAGE_SIZE,
      });
      dbg("useAgentActionsQuery: REST fetch DONE", {
        threadId,
        total: res.total,
        returned: res.items.length,
        items: res.items.map((a) => ({
          id: a.id,
          tool_name: a.tool_name,
          tool_call_id: a.tool_call_id,
          reversible: a.reversible,
          reverted_by_action_id: a.reverted_by_action_id,
          is_revert_action: a.is_revert_action,
        })),
      });
      return res;
    },
    enabled,
    // 15 s freshness window — remounts within it reuse the cache
    // instead of refetching on every card mount.
    staleTime: 15 * 1000,
  });
  const items = useMemo(() => query.data?.items ?? [], [query.data]);
  // Index ``items`` once per change so the lookups below are O(1)
  // instead of O(N) per card per render. With the cache sized to 200
  // rows and many tool cards visible at once, the unindexed scan was
  // the hottest path on every assistant text-delta. (Vercel React
  // rule ``js-index-maps`` / ``js-set-map-lookups``.)
  const byToolCallId = useMemo(() => {
    const m = new Map<string, AgentAction>();
    for (const a of items) {
      if (a.tool_call_id) m.set(a.tool_call_id, a);
    }
    return m;
  }, [items]);
  // Pre-grouped + pre-sorted (oldest-first, the order the agent
  // actually executed them in) so the (chat_turn_id, tool_name,
  // position) fallback in ``tool-fallback.tsx`` is also O(1) per
  // card. Excludes ``is_revert_action`` rows so the position index
  // matches the agent's original execution order.
  const byTurnAndTool = useMemo(() => {
    const m = new Map<string, AgentAction[]>();
    for (const a of items) {
      if (!a.chat_turn_id || a.is_revert_action) continue;
      const key = `${a.chat_turn_id}::${a.tool_name}`;
      const bucket = m.get(key);
      if (bucket) bucket.push(a);
      else m.set(key, [a]);
    }
    for (const bucket of m.values()) {
      bucket.sort(
        (a, b) =>
          new Date(a.created_at).getTime() - new Date(b.created_at).getTime()
      );
    }
    return m;
  }, [items]);
  // Snapshot the cache shape when its size changes — easiest way to
  // spot when the cache is empty or stale at the moment a card
  // mounts. Tracked on a ref so we don't re-run the diff on
  // reference-equal cache reads.
  const lastSnapshotRef = useRef<{ threadId: number | null; size: number } | null>(null);
  useEffect(() => {
    const last = lastSnapshotRef.current;
    if (!last || last.threadId !== threadId || last.size !== items.length) {
      dbg("useAgentActionsQuery: cache snapshot", {
        threadId,
        enabled,
        itemCount: items.length,
        itemKeys: items.slice(0, 8).map((a) => ({
          id: a.id,
          tool_name: a.tool_name,
          tool_call_id: a.tool_call_id,
          chat_turn_id: a.chat_turn_id,
          reversible: a.reversible,
        })),
      });
      lastSnapshotRef.current = { threadId, size: items.length };
    }
  }, [threadId, enabled, items]);
  // O(1) lookup via the memoized index. A MISS is only logged when the
  // cache is non-empty — an empty cache is expected pre-hydration.
  const findByToolCallId = useCallback(
    (toolCallId: string | null | undefined): AgentAction | null => {
      if (!toolCallId) return null;
      const found = byToolCallId.get(toolCallId) ?? null;
      if (!found && items.length > 0) {
        dbg("findByToolCallId: MISS", {
          queriedToolCallId: toolCallId,
          itemCount: items.length,
          availableToolCallIds: Array.from(byToolCallId.keys()),
        });
      }
      return found;
    },
    [byToolCallId, items.length]
  );
  const findByChatTurnId = useCallback(
    (chatTurnId: string | null | undefined): AgentAction[] => {
      if (!chatTurnId) return [];
      // Per-turn aggregation is uncommon enough (only the
      // "Revert turn" button uses it) that re-scanning is fine;
      // indexing it would just bloat memory.
      return items.filter((a) => a.chat_turn_id === chatTurnId);
    },
    [items]
  );
  const findByChatTurnAndTool = useCallback(
    (
      chatTurnId: string | null | undefined,
      toolName: string | null | undefined
    ): AgentAction[] => {
      if (!chatTurnId || !toolName) return [];
      return byTurnAndTool.get(`${chatTurnId}::${toolName}`) ?? [];
    },
    [byTurnAndTool]
  );
  // Spread the raw query state so consumers keep isLoading / refetch /
  // error etc., augmented with the memoized items and O(1) selectors.
  return {
    ...query,
    items,
    findByToolCallId,
    findByChatTurnId,
    findByChatTurnAndTool,
  };
}

View file

@ -31,6 +31,14 @@ export function useMessagesSync(
content: msg.content,
author_id: msg.authorId ?? null,
created_at: new Date(msg.createdAt).toISOString(),
// Forward the per-turn correlation id so post-stream Zero
// re-syncs preserve ``metadata.custom.chatTurnId`` on the
// converted ``ThreadMessageLike``. Without this the inline
// Revert button's ``(chat_turn_id, tool_name, position)``
// fallback breaks the moment Zero overwrites the messages
// state after a live stream completes (see
// ``handleSyncedMessagesUpdate`` in the chat page).
turn_id: msg.turnId ?? null,
}));
onMessagesUpdateRef.current(mapped);

View file

@ -15,6 +15,12 @@ const AgentActionReadSchema = z.object({
reverse_of: z.number().nullable(),
reverted_by_action_id: z.number().nullable(),
is_revert_action: z.boolean(),
// Correlation ids added in migration 135. The LangChain
// ``tool_call_id`` joins this row to the chat tool card via the
// ``data-action-log.lc_tool_call_id`` SSE event, and
// ``chat_turn_id`` keys the per-turn revert endpoint.
tool_call_id: z.string().nullable().optional(),
chat_turn_id: z.string().nullable().optional(),
created_at: z.string(),
});
@ -38,6 +44,48 @@ const RevertResponseSchema = z.object({
export type RevertResponse = z.infer<typeof RevertResponseSchema>;
// Per-turn batch revert. The route never returns whole-batch 4xx;
// partial success is the common case and surfaced as
// ``status === "partial"`` with a per-action result list.

/** Outcome for one action inside a per-turn batch revert. */
const RevertTurnActionResultSchema = z.object({
  action_id: z.number(),
  tool_name: z.string(),
  status: z.enum([
    "reverted",
    "already_reverted",
    "not_reversible",
    "permission_denied",
    "failed",
    "skipped",
  ]),
  message: z.string().nullable().optional(),
  // Id of the newly inserted revert-action row, when the server
  // created one (only meaningful for "reverted").
  new_action_id: z.number().nullable().optional(),
  error: z.string().nullable().optional(),
});

export type RevertTurnActionResult = z.infer<typeof RevertTurnActionResultSchema>;

/** Aggregate response of ``POST /threads/{id}/revert-turn/{turn}``. */
const RevertTurnResponseSchema = z.object({
  status: z.enum(["ok", "partial"]),
  chat_turn_id: z.string(),
  total: z.number(),
  reverted: z.number(),
  already_reverted: z.number(),
  not_reversible: z.number(),
  // ``permission_denied`` and ``skipped`` are first-class counters so
  // ``total === reverted + already_reverted +
  // not_reversible + permission_denied + failed + skipped`` always
  // holds. ``.default(0)`` keeps the schema backwards-compatible
  // with older deployments that haven't shipped the response model
  // update yet.
  permission_denied: z.number().default(0),
  failed: z.number(),
  skipped: z.number().default(0),
  results: z.array(RevertTurnActionResultSchema),
});

export type RevertTurnResponse = z.infer<typeof RevertTurnResponseSchema>;
class AgentActionsApiService {
listForThread = async (
threadId: number,
@ -59,6 +107,14 @@ class AgentActionsApiService {
{ body: {} }
);
};
/**
 * Revert every reversible action recorded for one chat turn.
 * The backend reports per-action outcomes rather than failing the
 * whole batch; see ``RevertTurnResponseSchema``.
 */
revertTurn = async (threadId: number, chatTurnId: string): Promise<RevertTurnResponse> => {
	const url = `/api/v1/threads/${threadId}/revert-turn/${encodeURIComponent(chatTurnId)}`;
	return baseApiService.post(url, RevertTurnResponseSchema, { body: {} });
};
}
export const agentActionsApiService = new AgentActionsApiService();

View file

@ -0,0 +1,8 @@
/** Minimal document shape needed to derive a mention key. */
type MentionKeyInput = {
	id: number;
	document_type?: string | null;
};

/**
 * Build a stable key for a mentioned document.
 *
 * Combines the document type and id so documents of different types
 * that share a numeric id never collide. A missing or null type falls
 * back to the literal "UNKNOWN" bucket.
 */
export function getMentionDocKey(doc: MentionKeyInput): string {
	const docType = doc.document_type ?? "UNKNOWN";
	return `${docType}:${doc.id}`;
}

View file

@ -40,7 +40,7 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
}
const metadata =
msg.author_id || msg.token_usage
msg.author_id || msg.token_usage || msg.turn_id
? {
custom: {
...(msg.author_id && {
@ -50,6 +50,10 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
},
}),
...(msg.token_usage && { usage: msg.token_usage }),
// Surface ``chat_turn_id`` so the assistant message
// footer can scope its "Revert turn" button to just
// this turn's actions. Null on legacy rows.
...(msg.turn_id && { chatTurnId: msg.turn_id }),
},
}
: undefined;

View file

@ -9,21 +9,59 @@ export interface ThinkingStepData {
/**
 * A tool invocation rendered as a chat card.
 *
 * ``argsText`` carries the live / finalized JSON text for the tool's
 * input arguments:
 * - during streaming it is the accumulated partial JSON from
 *   ``tool-input-delta`` events (may be invalid JSON mid-stream;
 *   assistant-ui's argsText parser tolerates invalid JSON gracefully —
 *   changelog 0.7.32 / 0.7.78);
 * - on completion (``tool-input-available``) it is replaced with
 *   ``JSON.stringify(input, null, 2)`` so the post-stream card renders
 *   pretty-printed JSON instead of the model's possibly-fragmented
 *   formatting.
 * Per assistant-ui ``ThreadMessageLike`` precedence (changelog 0.11.6
 * ``d318c83``), a supplied ``argsText`` wins over
 * ``JSON.stringify(args)``.
 *
 * ``langchainToolCallId`` is the authoritative LangChain
 * ``tool_call.id`` propagated by the backend via ``langchainToolCallId``
 * on tool-input-start/available and tool-output-available events. It
 * joins a card to the matching ``AgentActionLog`` row exposed by
 * ``GET /threads/{id}/actions`` and the streamed ``data-action-log``
 * events.
 */
type ToolCallPart = {
	type: "tool-call";
	toolCallId: string;
	toolName: string;
	args: Record<string, unknown>;
	result?: unknown;
	argsText?: string;
	langchainToolCallId?: string;
};

export type ContentPart =
	| { type: "text"; text: string }
	| { type: "reasoning"; text: string }
	| ToolCallPart
	| {
			type: "data-thinking-steps";
			data: { steps: ThinkingStepData[] };
	  }
	| {
			/**
			 * Between-step separator. Pushed by `addStepSeparator` when
			 * a `start-step` SSE event arrives AFTER the message already
			 * has non-step content. Rendered by `StepSeparatorDataUI`
			 * (see assistant-ui/step-separator.tsx).
			 */
			type: "data-step-separator";
			data: { stepIndex: number };
	  };
// Mutable accumulator threaded through the SSE handlers while an
// assistant message streams in.
export interface ContentPartsState {
	// Ordered parts that become the message's renderable content.
	contentParts: ContentPart[];
	// Index into ``contentParts`` of the text part currently receiving
	// deltas, or -1 when no text block is open.
	currentTextPartIndex: number;
	// Index of the reasoning part currently receiving deltas, or -1
	// when no reasoning block is open.
	currentReasoningPartIndex: number;
	// Maps a tool call's UI id to its index in ``contentParts``.
	toolCallIndices: Map<string, number>;
}
@ -74,6 +112,9 @@ export function updateThinkingSteps(
if (state.currentTextPartIndex >= 0) {
state.currentTextPartIndex += 1;
}
if (state.currentReasoningPartIndex >= 0) {
state.currentReasoningPartIndex += 1;
}
for (const [id, idx] of state.toolCallIndices) {
state.toolCallIndices.set(id, idx + 1);
}
@ -131,6 +172,12 @@ export class FrameBatchedUpdater {
}
export function appendText(state: ContentPartsState, delta: string): void {
// First text delta after a reasoning block: close the reasoning so
// the assistant-ui renderer treats them as separate parts (the
// reasoning block collapses; the answer streams below).
if (state.currentReasoningPartIndex >= 0) {
state.currentReasoningPartIndex = -1;
}
if (
state.currentTextPartIndex >= 0 &&
state.contentParts[state.currentTextPartIndex]?.type === "text"
@ -143,39 +190,161 @@ export function appendText(state: ContentPartsState, delta: string): void {
}
}
/**
 * Accumulate a reasoning delta into the active reasoning block,
 * opening a fresh block on the first delta. Symmetric to
 * ``appendText``: starting reasoning closes any open text block.
 * ``endReasoning`` simply closes the active block; subsequent
 * reasoning deltas then open a new one (matching ``text-start/end``
 * semantics on the wire).
 */
export function appendReasoning(state: ContentPartsState, delta: string): void {
	// A reasoning delta implicitly ends the in-progress text part.
	if (state.currentTextPartIndex >= 0) {
		state.currentTextPartIndex = -1;
	}
	const idx = state.currentReasoningPartIndex;
	const active = idx >= 0 ? state.contentParts[idx] : undefined;
	if (active?.type === "reasoning") {
		(active as { type: "reasoning"; text: string }).text += delta;
	} else {
		const length = state.contentParts.push({ type: "reasoning", text: delta });
		state.currentReasoningPartIndex = length - 1;
	}
}
/**
 * Close the active reasoning block. The next reasoning delta will
 * open a fresh part (see ``appendReasoning``).
 */
export function endReasoning(state: ContentPartsState): void {
	state.currentReasoningPartIndex = -1;
}
/**
 * Push a divider between consecutive model steps within a single
 * assistant turn. A separator is only emitted when the message already
 * has renderable content (so the FIRST step of a turn never produces a
 * leading divider) and when the previous part is not itself a
 * separator (defensive against duplicate `start-step` events).
 * Pushing a separator closes any open text/reasoning block.
 */
export function addStepSeparator(state: ContentPartsState): void {
	const parts = state.contentParts;
	// Single pass: detect renderable content and count existing
	// separators (the new separator's stepIndex).
	let hasRenderable = false;
	let separatorCount = 0;
	for (const part of parts) {
		if (part.type === "data-step-separator") {
			separatorCount += 1;
		} else if (
			part.type === "text" ||
			part.type === "reasoning" ||
			part.type === "tool-call"
		) {
			hasRenderable = true;
		}
	}
	if (!hasRenderable) return;
	if (parts[parts.length - 1]?.type === "data-step-separator") return;
	parts.push({ type: "data-step-separator", data: { stepIndex: separatorCount } });
	state.currentTextPartIndex = -1;
	state.currentReasoningPartIndex = -1;
}
/**
 * Gate controlling which tool names produce a UI tool card. The
 * sentinel ``"all"`` matches every tool — we dropped the legacy
 * ``BASE_TOOLS_WITH_UI`` allowlist so that ALL tool calls render via
 * the generic ``ToolFallback``. The backend's ``format_thinking_step``
 * summarisation and the defensive ``result_length``-only default for
 * unknown tools keep persisted message JSON from ballooning.
 */
export type ToolUIGate = Set<string> | "all";

/** True when the gate admits ``toolName`` (sentinel or set member). */
function _toolPasses(gate: ToolUIGate, toolName: string): boolean {
	if (gate === "all") return true;
	return gate.has(toolName);
}
export function addToolCall(
state: ContentPartsState,
toolsWithUI: Set<string>,
toolsWithUI: ToolUIGate,
toolCallId: string,
toolName: string,
args: Record<string, unknown>,
force = false
force = false,
langchainToolCallId?: string
): void {
if (force || toolsWithUI.has(toolName)) {
if (force || _toolPasses(toolsWithUI, toolName)) {
state.contentParts.push({
type: "tool-call",
toolCallId,
toolName,
args,
...(langchainToolCallId ? { langchainToolCallId } : {}),
});
state.toolCallIndices.set(toolCallId, state.contentParts.length - 1);
state.currentTextPartIndex = -1;
state.currentReasoningPartIndex = -1;
}
}
/**
 * Reverse lookup used by the SSE ``data-action-log`` handler: given
 * the LangChain ``tool_call.id`` (stored on the content part as
 * ``langchainToolCallId``), return the synthetic ``toolCallId`` the
 * chat tool card uses (``call_<run-id>``). Returns ``null`` when no
 * matching tool card has been seen yet — the action is still recorded
 * in the LC-id-keyed atom, so the card can pick it up when it
 * eventually arrives.
 */
export function findToolCallIdByLcId(
	state: ContentPartsState,
	lcToolCallId: string
): string | null {
	for (const part of state.contentParts) {
		if (part.type !== "tool-call") continue;
		if (part.langchainToolCallId === lcToolCallId) return part.toolCallId;
	}
	return null;
}
export function updateToolCall(
state: ContentPartsState,
toolCallId: string,
update: { args?: Record<string, unknown>; result?: unknown }
update: {
args?: Record<string, unknown>;
argsText?: string;
result?: unknown;
langchainToolCallId?: string;
}
): void {
const index = state.toolCallIndices.get(toolCallId);
if (index !== undefined && state.contentParts[index]?.type === "tool-call") {
const tc = state.contentParts[index] as ContentPart & { type: "tool-call" };
if (update.args) tc.args = update.args;
// ``!== undefined`` (NOT a truthy check): an explicit empty
// string CAN clear, and a finalization with
// ``JSON.stringify({}, null, 2) === "{}"`` (truthy but
// represents an empty-input call) still applies.
if (update.argsText !== undefined) tc.argsText = update.argsText;
if (update.result !== undefined) tc.result = update.result;
// Only backfill langchainToolCallId if not already set — the
// authoritative ``on_tool_end`` value should override an earlier
// best-effort match, but a NULL late-arriving value should not
// blow away a known good early one.
if (update.langchainToolCallId && !tc.langchainToolCallId) {
tc.langchainToolCallId = update.langchainToolCallId;
}
}
}
/**
 * Append a streamed args-delta chunk to the matching tool call's
 * ``argsText``. No-ops when no card has been registered for
 * ``toolCallId`` (the matching ``tool-input-start`` either lost the
 * wire race or this id never had a card — either way the delta has
 * nowhere safe to land).
 */
export function appendToolInputDelta(
	state: ContentPartsState,
	toolCallId: string,
	delta: string
): void {
	const index = state.toolCallIndices.get(toolCallId);
	if (index === undefined) return;
	const part = state.contentParts[index];
	if (part === undefined || part.type !== "tool-call") return;
	const accumulated = part.argsText ?? "";
	part.argsText = accumulated + delta;
}
function _hasInterruptResult(part: ContentPart): boolean {
if (part.type !== "tool-call") return false;
const r = (part as { result?: unknown }).result;
@ -184,13 +353,15 @@ function _hasInterruptResult(part: ContentPart): boolean {
export function buildContentForUI(
state: ContentPartsState,
toolsWithUI: Set<string>
toolsWithUI: ToolUIGate
): ThreadMessageLike["content"] {
const filtered = state.contentParts.filter((part) => {
if (part.type === "text") return part.text.length > 0;
if (part.type === "reasoning") return part.text.length > 0;
if (part.type === "tool-call")
return toolsWithUI.has(part.toolName) || _hasInterruptResult(part);
return _toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part);
if (part.type === "data-thinking-steps") return true;
if (part.type === "data-step-separator") return true;
return false;
});
return filtered.length > 0
@ -200,20 +371,28 @@ export function buildContentForUI(
export function buildContentForPersistence(
state: ContentPartsState,
toolsWithUI: Set<string>
toolsWithUI: ToolUIGate
): unknown[] {
const parts: unknown[] = [];
for (const part of state.contentParts) {
if (part.type === "text" && part.text.length > 0) {
parts.push(part);
} else if (part.type === "reasoning" && part.text.length > 0) {
// Persist reasoning blocks so a chat reload re-renders the
// collapsed thinking section instead of
// silently dropping it (mirrors the data-thinking-steps
// branch above).
parts.push(part);
} else if (
part.type === "tool-call" &&
(toolsWithUI.has(part.toolName) || _hasInterruptResult(part))
(_toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part))
) {
parts.push(part);
} else if (part.type === "data-thinking-steps") {
parts.push(part);
} else if (part.type === "data-step-separator") {
parts.push(part);
}
}
@ -221,23 +400,134 @@ export function buildContentForPersistence(
}
export type SSEEvent =
| { type: "text-delta"; delta: string }
| { type: "tool-input-start"; toolCallId: string; toolName: string }
| { type: "start"; messageId?: string }
| { type: "finish" }
| { type: "start-step" }
| { type: "finish-step" }
| { type: "text-start"; id: string }
| { type: "text-delta"; id?: string; delta: string }
| { type: "text-end"; id: string }
| { type: "reasoning-start"; id: string }
| { type: "reasoning-delta"; id?: string; delta: string }
| { type: "reasoning-end"; id: string }
| {
type: "tool-input-start";
toolCallId: string;
toolName: string;
/** Authoritative LangChain ``tool_call.id``. Optional. */
langchainToolCallId?: string;
}
| {
/**
* Live tool-call argument delta. Concatenated into
* ``argsText`` on the matching ``tool-call`` content part
* by ``appendToolInputDelta``. parity_v2 only the legacy
* code path emits ``tool-input-available`` without prior
* deltas.
*/
type: "tool-input-delta";
toolCallId: string;
inputTextDelta: string;
}
| {
type: "tool-input-available";
toolCallId: string;
toolName: string;
input: Record<string, unknown>;
langchainToolCallId?: string;
}
| {
type: "tool-output-available";
toolCallId: string;
output: Record<string, unknown>;
/** Authoritative LangChain ``tool_call.id`` extracted from
* ``ToolMessage.tool_call_id`` at on_tool_end. Backfills cards
* that didn't get the id at tool-input-start time. */
langchainToolCallId?: string;
}
| { type: "data-thinking-step"; data: ThinkingStepData }
| { type: "data-thread-title-update"; data: { threadId: number; title: string } }
| { type: "data-interrupt-request"; data: Record<string, unknown> }
| { type: "data-documents-updated"; data: Record<string, unknown> }
| {
/**
* A freshly committed AgentActionLog row. Frontend stores
* this in a Map keyed off ``lc_tool_call_id`` so the chat
* tool card can light up its Revert button.
*/
type: "data-action-log";
data: {
id: number;
lc_tool_call_id: string | null;
chat_turn_id: string | null;
tool_name: string;
reversible: boolean;
reverse_descriptor_present: boolean;
created_at: string | null;
error: boolean;
};
}
| {
/**
* Reversibility flipped (filesystem op SAVEPOINT committed;
* cf. ``kb_persistence._dispatch_reversibility_update``).
*/
type: "data-action-log-updated";
data: { id: number; reversible: boolean };
}
| {
/**
* Emitted at the start of every stream so the frontend can
* stamp the per-turn correlation id onto the in-flight
* assistant message and replay it via
* ``appendMessage``. Pure-text turns never produce
* action-log events; this event guarantees the frontend
* always learns the turn id.
*/
type: "data-turn-info";
data: { chat_turn_id: string };
}
| {
/**
* Best-effort revert pass that ran BEFORE this regeneration.
* Per-action results are forwarded to the UI so the user
* can see which downstream actions were rolled
* back vs which couldn't be undone.
*/
type: "data-revert-results";
data: {
status: "ok" | "partial";
chat_turn_ids: string[];
total: number;
reverted: number;
already_reverted: number;
not_reversible: number;
/**
* ``permission_denied`` and ``skipped`` are first-class
* counters so the response invariant
* ``total === sum(counters)`` always holds. Optional
* for forward compatibility with older backends; the
* frontend treats missing values as ``0``.
*/
permission_denied?: number;
failed: number;
skipped?: number;
results: Array<{
action_id: number;
tool_name: string;
status:
| "reverted"
| "already_reverted"
| "not_reversible"
| "permission_denied"
| "failed"
| "skipped";
message?: string | null;
new_action_id?: number | null;
error?: string | null;
}>;
};
}
| {
type: "data-token-usage";
data: {

View file

@ -46,6 +46,11 @@ export interface MessageRecord {
author_display_name?: string | null;
author_avatar_url?: string | null;
token_usage?: TokenUsageSummary | null;
// Per-turn correlation id from ``configurable.turn_id`` at streaming
// time (added in migration 136). Used by the per-turn revert
// endpoint and edit-from-arbitrary-position. Nullable on legacy
// rows that predate the column.
turn_id?: string | null;
}
export interface ThreadListResponse {
@ -123,10 +128,20 @@ export async function getThreadMessages(threadId: number): Promise<ThreadHistory
/**
* Append a message to a thread.
*
* ``turn_id`` is the per-turn correlation id streamed by the backend
* via ``data-turn-info``. Persisting it lets later edits locate the
* matching LangGraph checkpoint without HumanMessage scanning. Older
* callers can still omit it for back-compat.
*/
export async function appendMessage(
threadId: number,
message: { role: "user" | "assistant" | "system"; content: unknown; token_usage?: unknown }
message: {
role: "user" | "assistant" | "system";
content: unknown;
token_usage?: unknown;
turn_id?: string | null;
}
): Promise<MessageRecord> {
return baseApiService.post<MessageRecord>(`/api/v1/threads/${threadId}/messages`, undefined, {
body: message,

View file

@ -1,12 +1,39 @@
import { loader } from "fumadocs-core/source";
import { icons } from "lucide-react";
import {
BookOpen,
ClipboardCheck,
Compass,
Container,
Download,
FlaskConical,
Heart,
Unplug,
Wrench,
} from "lucide-react";
import { createElement } from "react";
import { docs } from "@/.source/server";
/** Explicit whitelist of Lucide icons used in docs frontmatter / meta.json.
* Importing the full `icons` barrel would pull every Lucide icon (~1 400 SVGs)
* into the docs bundle even though only a handful are referenced. Add new icons
* here as docs pages are added.
*/
const DOCS_ICONS: Record<string, React.ComponentType> = {
BookOpen,
ClipboardCheck,
Compass,
Container,
Download,
FlaskConical,
Heart,
Unplug,
Wrench,
};
export const source = loader({
baseUrl: "/docs",
source: docs.toFumadocsSource(),
icon(icon) {
if (icon && icon in icons) return createElement(icons[icon as keyof typeof icons]);
if (icon && icon in DOCS_ICONS) return createElement(DOCS_ICONS[icon]);
},
});

View file

@ -8,6 +8,13 @@ export const newChatMessageTable = table("new_chat_messages")
threadId: number().from("thread_id"),
authorId: string().optional().from("author_id"),
createdAt: number().from("created_at"),
// Per-turn correlation id sourced from ``configurable.turn_id``
// at streaming time. Required by the inline Revert button's
// (chat_turn_id, tool_name, position) fallback in tool-fallback.tsx
// — without it the live-collab Zero sync would clobber the
// metadata we set during streaming and the button would vanish
// the moment Zero re-syncs after the stream finishes.
turnId: string().optional().from("turn_id"),
})
.primaryKey("id");