mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-07-02 22:01:05 +02:00
feat: improved agent streaming
This commit is contained in:
parent
afb4b09cde
commit
c110f5b955
60 changed files with 8068 additions and 303 deletions
|
|
@ -282,6 +282,14 @@ LANGSMITH_PROJECT=surfsense
|
||||||
# SURFSENSE_ENABLE_ACTION_LOG=false
|
# SURFSENSE_ENABLE_ACTION_LOG=false
|
||||||
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
|
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
|
||||||
|
|
||||||
|
# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk
|
||||||
|
# content (typed reasoning blocks, tool-input deltas) and propagate the
|
||||||
|
# real tool_call_id to the SSE layer. When OFF, the stream falls back to
|
||||||
|
# the str-only text path and synthetic "call_<run_id>" tool-call ids.
|
||||||
|
# Schema migrations 135/136 ship unconditionally because they are
|
||||||
|
# forward-compatible.
|
||||||
|
# SURFSENSE_ENABLE_STREAM_PARITY_V2=false
|
||||||
|
|
||||||
# Plugins
|
# Plugins
|
||||||
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
||||||
# Comma-separated allowlist of plugin entry-point names
|
# Comma-separated allowlist of plugin entry-point names
|
||||||
|
|
|
||||||
139
surfsense_backend/alembic/versions/134_relax_revision_fks.py
Normal file
139
surfsense_backend/alembic/versions/134_relax_revision_fks.py
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
"""134_relax_revision_fks
|
||||||
|
|
||||||
|
Revision ID: 134
|
||||||
|
Revises: 133
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Relax the parent FKs on ``document_revisions`` and ``folder_revisions`` so
|
||||||
|
revisions survive the deletes they describe.
|
||||||
|
|
||||||
|
Why: the snapshot/revert pipeline writes a ``DocumentRevision`` BEFORE
|
||||||
|
hard-deleting a document via the ``rm`` tool (and likewise a
|
||||||
|
``FolderRevision`` before ``rmdir``). If the FK is ``ON DELETE CASCADE``
|
||||||
|
the snapshot row is wiped at the exact moment we need it most — revert
|
||||||
|
then has nothing to read and the operation becomes irreversible.
|
||||||
|
|
||||||
|
Migration:
|
||||||
|
|
||||||
|
* ``document_revisions.document_id``: ``NOT NULL`` -> nullable; FK
|
||||||
|
``ON DELETE CASCADE`` -> ``ON DELETE SET NULL``.
|
||||||
|
* ``folder_revisions.folder_id``: same treatment.
|
||||||
|
|
||||||
|
The ``search_space_id`` FK on both tables is left unchanged (still
|
||||||
|
``ON DELETE CASCADE``). When a search space is deleted, all documents,
|
||||||
|
folders, AND their revisions go together — that's the correct teardown
|
||||||
|
story.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "134"
|
||||||
|
down_revision: str | None = "133"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _fk_name(bind, table: str, column: str) -> str | None:
    """Look up the name of the foreign key constraining exactly ``table.column``.

    Args:
        bind: Connection (or engine) handed to ``sqlalchemy.inspect``.
        table: Table whose foreign keys are reflected.
        column: Column that must be the FK's sole constrained column.

    Returns:
        The ``name`` entry of the first reflected FK whose constrained
        columns equal ``[column]``; ``None`` when no FK matches (or when
        the matching FK reports no name).
    """
    target = [column]
    names = (
        fk.get("name")
        for fk in inspect(bind).get_foreign_keys(table)
        if (fk.get("constrained_columns") or []) == target
    )
    return next(names, None)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Make the parent FKs on both revision tables nullable with ``SET NULL``.

    For each table, in order: drop the FK currently on the parent column
    (its name is reflected rather than assumed, and the drop is skipped
    when none is found), make the column nullable, then recreate the FK
    with ``ON DELETE SET NULL``.
    """
    connection = op.get_bind()

    # (child table, parent column, parent table, name for the recreated FK)
    relaxed_fks = (
        (
            "document_revisions",
            "document_id",
            "documents",
            "document_revisions_document_id_fkey",
        ),
        (
            "folder_revisions",
            "folder_id",
            "folders",
            "folder_revisions_folder_id_fkey",
        ),
    )

    for child_table, parent_column, parent_table, new_fk_name in relaxed_fks:
        current_fk = _fk_name(connection, child_table, parent_column)
        if current_fk:
            op.drop_constraint(current_fk, child_table, type_="foreignkey")
        op.alter_column(
            child_table,
            parent_column,
            existing_type=sa.Integer(),
            nullable=True,
        )
        op.create_foreign_key(
            new_fk_name,
            child_table,
            parent_table,
            [parent_column],
            ["id"],
            ondelete="SET NULL",
        )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Restore ``NOT NULL`` + ``ON DELETE CASCADE`` on both revision parent FKs.

    Destructive: revision rows whose parent column is NULL are deleted
    first, because they cannot satisfy the reinstated ``NOT NULL``.
    """
    connection = op.get_bind()

    # Orphaned revisions (parent doc/folder already deleted) have to be
    # drained before the column can go back to NOT NULL.
    op.execute("DELETE FROM document_revisions WHERE document_id IS NULL")
    op.execute("DELETE FROM folder_revisions WHERE folder_id IS NULL")

    # (child table, parent column, parent table, name for the recreated FK)
    restored_fks = (
        (
            "document_revisions",
            "document_id",
            "documents",
            "document_revisions_document_id_fkey",
        ),
        (
            "folder_revisions",
            "folder_id",
            "folders",
            "folder_revisions_folder_id_fkey",
        ),
    )

    for child_table, parent_column, parent_table, new_fk_name in restored_fks:
        current_fk = _fk_name(connection, child_table, parent_column)
        if current_fk:
            op.drop_constraint(current_fk, child_table, type_="foreignkey")
        op.alter_column(
            child_table,
            parent_column,
            existing_type=sa.Integer(),
            nullable=False,
        )
        op.create_foreign_key(
            new_fk_name,
            child_table,
            parent_table,
            [parent_column],
            ["id"],
            ondelete="CASCADE",
        )
|
||||||
|
|
@ -0,0 +1,82 @@
|
||||||
|
"""135_action_log_correlation_ids
|
||||||
|
|
||||||
|
Revision ID: 135
|
||||||
|
Revises: 134
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Action-log correlation-id cleanup.
|
||||||
|
|
||||||
|
Background
|
||||||
|
----------
|
||||||
|
``agent_action_log.turn_id`` is misnamed. ``ActionLogMiddleware`` writes
|
||||||
|
the LangChain ``tool_call.id`` into that column today (see
|
||||||
|
``action_log.py:_resolve_turn_id``), and ``kb_persistence._find_action_ids_batch``
|
||||||
|
joins on it as such. The real chat-turn id (``f"{chat_id}:{ms}"`` from
|
||||||
|
``stream_new_chat.py``) lives in ``config.configurable.turn_id`` and was
|
||||||
|
never persisted.
|
||||||
|
|
||||||
|
This migration introduces two new, correctly-named columns:
|
||||||
|
|
||||||
|
* ``tool_call_id`` (LangChain tool-call id, what ``turn_id`` actually held)
|
||||||
|
* ``chat_turn_id`` (the per-turn correlation id from
|
||||||
|
``configurable.turn_id`` — used by the per-turn ``revert-turn`` route).
|
||||||
|
|
||||||
|
Backfill copies the current ``turn_id`` values into ``tool_call_id`` so
|
||||||
|
existing joins keep working. The old ``turn_id`` column is left in place
|
||||||
|
for one release as a deprecated alias to give safe rollback. ``ActionLogMiddleware``
|
||||||
|
keeps writing it (= ``tool_call_id``) for the same reason.
|
||||||
|
|
||||||
|
Indexes
|
||||||
|
-------
|
||||||
|
|
||||||
|
* ``ix_agent_action_log_tool_call_id`` — required by
|
||||||
|
``_find_action_ids_batch`` (was on ``turn_id``).
|
||||||
|
* ``ix_agent_action_log_chat_turn_id`` — required by the
|
||||||
|
``revert-turn/{chat_turn_id}`` query.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "135"
|
||||||
|
down_revision: str | None = "134"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add ``tool_call_id`` / ``chat_turn_id`` to ``agent_action_log``.

    Steps, in order:

    1. Add both columns as nullable ``VARCHAR(64)``.
    2. Backfill ``tool_call_id`` from the legacy ``turn_id`` column.
    3. Create one index per new column.

    ``chat_turn_id`` is not backfilled by this migration, and the legacy
    ``turn_id`` column is left untouched.
    """
    op.add_column(
        "agent_action_log",
        sa.Column("tool_call_id", sa.String(length=64), nullable=True),
    )
    op.add_column(
        "agent_action_log",
        sa.Column("chat_turn_id", sa.String(length=64), nullable=True),
    )

    # Backfill BEFORE building the indexes: a bulk UPDATE against an
    # already-indexed column pays index maintenance for every row, while
    # building the index afterwards is a single pass. Rows whose
    # ``turn_id`` is NULL are skipped because copying NULL over NULL
    # would only rewrite the row for no effect.
    op.execute(
        "UPDATE agent_action_log SET tool_call_id = turn_id "
        "WHERE tool_call_id IS NULL AND turn_id IS NOT NULL"
    )

    op.create_index(
        "ix_agent_action_log_tool_call_id",
        "agent_action_log",
        ["tool_call_id"],
    )
    op.create_index(
        "ix_agent_action_log_chat_turn_id",
        "agent_action_log",
        ["chat_turn_id"],
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the two indexes, then the two columns, added by ``upgrade``.

    The legacy ``turn_id`` column is not touched.
    """
    table = "agent_action_log"

    for index_name in (
        "ix_agent_action_log_chat_turn_id",
        "ix_agent_action_log_tool_call_id",
    ):
        op.drop_index(index_name, table_name=table)

    for column_name in ("chat_turn_id", "tool_call_id"):
        op.drop_column(table, column_name)
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
"""136_new_chat_message_turn_id
|
||||||
|
|
||||||
|
Revision ID: 136
|
||||||
|
Revises: 135
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Persist the per-turn correlation id on each chat message.
|
||||||
|
|
||||||
|
Background
|
||||||
|
----------
|
||||||
|
LangGraph's checkpointer stores user-provided ``configurable.turn_id``
|
||||||
|
in checkpoint metadata (see
|
||||||
|
``langgraph/checkpoint/base/__init__.py:get_checkpoint_metadata``). To
|
||||||
|
support edit-from-arbitrary-position, the regenerate route needs to map
|
||||||
|
a ``message_id`` -> ``turn_id`` -> checkpoint at request time. Without
|
||||||
|
this column the mapping doesn't exist anywhere, so regenerate would
|
||||||
|
have to hardcode the "last 2 messages" rewind heuristic.
|
||||||
|
|
||||||
|
This migration adds a nullable ``turn_id`` column to ``new_chat_messages``
|
||||||
|
plus an index. Legacy rows have NULL — the regenerate route degrades
|
||||||
|
gracefully to the reload-last-two heuristic for those.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "136"
|
||||||
|
down_revision: str | None = "135"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add a nullable ``turn_id`` column to ``new_chat_messages`` and index it."""
    table = "new_chat_messages"

    # Nullable on purpose: rows written before this migration have no turn id.
    turn_id_column = sa.Column("turn_id", sa.String(length=64), nullable=True)
    op.add_column(table, turn_id_column)

    op.create_index("ix_new_chat_messages_turn_id", table, ["turn_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Remove the ``turn_id`` index and then the column itself."""
    table = "new_chat_messages"
    op.drop_index("ix_new_chat_messages_turn_id", table_name=table)
    op.drop_column(table, "turn_id")
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
"""137_unique_reverse_of_in_action_log
|
||||||
|
|
||||||
|
Revision ID: 137
|
||||||
|
Revises: 136
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Protect ``agent_action_log.reverse_of`` against double inserts. Two
|
||||||
|
concurrent revert calls (single-action route + the per-turn batch
|
||||||
|
route, or two batch routes racing) both pass the
|
||||||
|
``_was_already_reverted`` SELECT and both insert their own
|
||||||
|
``_revert:*`` rows. The application-level idempotency check is racy
|
||||||
|
because there's no DB constraint backing it.
|
||||||
|
|
||||||
|
This migration adds a partial unique index on ``reverse_of`` (PostgreSQL
|
||||||
|
``WHERE reverse_of IS NOT NULL``) so the second concurrent insert raises
|
||||||
|
``IntegrityError`` and the route can translate it to ``"already_reverted"``
|
||||||
|
deterministically.
|
||||||
|
|
||||||
|
A plain ``UniqueConstraint`` would also tolerate the existing data: most
rows have ``reverse_of = NULL`` (only revert rows fill it), and Postgres
treats NULLs as distinct in unique indexes by default. A partial index is
used instead because it is the cleanest expression of intent, keeps the
NULL rows out of the index entirely, and does not depend on the
``NULLS [NOT] DISTINCT`` behaviour introduced in PostgreSQL 15.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "137"
|
||||||
|
down_revision: str | None = "136"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
_INDEX_NAME = "ux_agent_action_log_reverse_of"
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """De-duplicate ``reverse_of`` values, then add the partial unique index.

    The UPDATE runs first so that index creation cannot fail on rows that
    already share a non-NULL ``reverse_of``.
    """
    # Defensively de-dup any pre-existing double-revert rows before
    # adding the unique index. Keeps the OLDEST row (smallest id) and
    # NULLs out the duplicates' ``reverse_of`` so they survive as audit
    # trail but no longer claim to be the canonical revert. We do NOT
    # delete them — operators can still inspect them via /actions.
    op.execute(
        """
        WITH dups AS (
            SELECT id,
                   reverse_of,
                   ROW_NUMBER() OVER (
                       PARTITION BY reverse_of ORDER BY id ASC
                   ) AS rn
            FROM agent_action_log
            WHERE reverse_of IS NOT NULL
        )
        UPDATE agent_action_log
        SET reverse_of = NULL
        WHERE id IN (SELECT id FROM dups WHERE rn > 1)
        """
    )

    # Unique only over rows that actually carry a ``reverse_of`` value.
    op.create_index(
        _INDEX_NAME,
        "agent_action_log",
        ["reverse_of"],
        unique=True,
        postgresql_where="reverse_of IS NOT NULL",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the partial unique index on ``reverse_of``.

    ``reverse_of`` values that ``upgrade`` set to NULL while de-duplicating
    are not restored.
    """
    op.drop_index(_INDEX_NAME, table_name="agent_action_log")
|
||||||
|
|
@ -724,7 +724,8 @@ def _build_compiled_agent_blocking(
|
||||||
repair_mw = None
|
repair_mw = None
|
||||||
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
|
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
|
||||||
registered_names: set[str] = {t.name for t in tools}
|
registered_names: set[str] = {t.name for t in tools}
|
||||||
# Tools owned by the standard deepagents middleware stack.
|
# Tools owned by the standard deepagents middleware stack and the
|
||||||
|
# SurfSense filesystem extension.
|
||||||
registered_names |= {
|
registered_names |= {
|
||||||
"write_todos",
|
"write_todos",
|
||||||
"ls",
|
"ls",
|
||||||
|
|
@ -735,6 +736,14 @@ def _build_compiled_agent_blocking(
|
||||||
"grep",
|
"grep",
|
||||||
"execute",
|
"execute",
|
||||||
"task",
|
"task",
|
||||||
|
"mkdir",
|
||||||
|
"cd",
|
||||||
|
"pwd",
|
||||||
|
"move_file",
|
||||||
|
"rm",
|
||||||
|
"rmdir",
|
||||||
|
"list_tree",
|
||||||
|
"execute_code",
|
||||||
}
|
}
|
||||||
repair_mw = ToolCallNameRepairMiddleware(
|
repair_mw = ToolCallNameRepairMiddleware(
|
||||||
registered_tool_names=registered_names,
|
registered_tool_names=registered_names,
|
||||||
|
|
@ -763,25 +772,51 @@ def _build_compiled_agent_blocking(
|
||||||
# on every safe read-only call (``ls``, ``read_file``, ``grep``,
|
# on every safe read-only call (``ls``, ``read_file``, ``grep``,
|
||||||
# ``glob``, ``web_search`` …) and, on resume, replay the previous
|
# ``glob``, ``web_search`` …) and, on resume, replay the previous
|
||||||
# reject decision into innocent calls.
|
# reject decision into innocent calls.
|
||||||
# 2. ``connector_synthesized`` — deny rules for tools whose required
|
# 2. ``desktop_safety`` — ``ask`` for destructive filesystem ops when
|
||||||
# connector is not connected to this space. Overrides #1.
|
# the agent is operating against the user's real disk. Cloud mode
|
||||||
# 3. (future) user-defined rules from ``agent_permission_rules`` table
|
# has full revision-based revert via ``revert_service``, but
|
||||||
# via the Agent Permissions UI. Loaded last so they override both.
|
# desktop mode hits disk immediately with no undo, so an
|
||||||
|
# accidental ``rm`` / ``rmdir`` / ``move_file`` / ``edit_file`` /
|
||||||
|
# ``write_file`` is unrecoverable. This layer is forced on in
|
||||||
|
# desktop mode regardless of ``enable_permission`` because the
|
||||||
|
# safety net is non-negotiable.
|
||||||
|
# 3. ``connector_synthesized`` — deny rules for tools whose required
|
||||||
|
# connector is not connected to this space. Overrides #1/#2.
|
||||||
|
# 4. (future) user-defined rules from ``agent_permission_rules`` table
|
||||||
|
# via the Agent Permissions UI. Loaded last so they override all.
|
||||||
permission_mw: PermissionMiddleware | None = None
|
permission_mw: PermissionMiddleware | None = None
|
||||||
if flags.enable_permission and not flags.disable_new_agent_stack:
|
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||||
synthesized = _synthesize_connector_deny_rules(
|
permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
|
||||||
available_connectors=available_connectors,
|
# Build the middleware whenever it has work to do: either the user
|
||||||
enabled_tool_names={t.name for t in tools},
|
# opted into the rule engine, OR we're in desktop mode and need the
|
||||||
)
|
# safety rules unconditionally.
|
||||||
permission_mw = PermissionMiddleware(
|
if permission_enabled or is_desktop_fs:
|
||||||
rulesets=[
|
rulesets: list[Ruleset] = [
|
||||||
|
Ruleset(
|
||||||
|
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||||
|
origin="surfsense_defaults",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
if is_desktop_fs:
|
||||||
|
rulesets.append(
|
||||||
Ruleset(
|
Ruleset(
|
||||||
rules=[Rule(permission="*", pattern="*", action="allow")],
|
rules=[
|
||||||
origin="surfsense_defaults",
|
Rule(permission="rm", pattern="*", action="ask"),
|
||||||
),
|
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||||
Ruleset(rules=synthesized, origin="connector_synthesized"),
|
Rule(permission="move_file", pattern="*", action="ask"),
|
||||||
],
|
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||||
)
|
Rule(permission="write_file", pattern="*", action="ask"),
|
||||||
|
],
|
||||||
|
origin="desktop_safety",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if permission_enabled:
|
||||||
|
synthesized = _synthesize_connector_deny_rules(
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
enabled_tool_names={t.name for t in tools},
|
||||||
|
)
|
||||||
|
rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized"))
|
||||||
|
permission_mw = PermissionMiddleware(rulesets=rulesets)
|
||||||
|
|
||||||
# ActionLogMiddleware. Off by default until the ``agent_action_log``
|
# ActionLogMiddleware. Off by default until the ``agent_action_log``
|
||||||
# table is migrated. When enabled, persists one row per tool call
|
# table is migrated. When enabled, persists one row per tool call
|
||||||
|
|
@ -938,6 +973,7 @@ def _build_compiled_agent_blocking(
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
created_by_id=user_id,
|
created_by_id=user_id,
|
||||||
filesystem_mode=filesystem_mode,
|
filesystem_mode=filesystem_mode,
|
||||||
|
thread_id=thread_id,
|
||||||
)
|
)
|
||||||
if filesystem_mode == FilesystemMode.CLOUD
|
if filesystem_mode == FilesystemMode.CLOUD
|
||||||
else None,
|
else None,
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ Local development (recommended for trying everything except doom-loop / selector
|
||||||
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
|
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
|
||||||
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
|
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
|
||||||
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
|
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
|
||||||
|
SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events
|
||||||
|
|
||||||
Master kill-switch (overrides everything else):
|
Master kill-switch (overrides everything else):
|
||||||
|
|
||||||
|
|
@ -86,6 +87,15 @@ class AgentFeatureFlags:
|
||||||
False # Backend ships before UI; route returns 503 until this flips
|
False # Backend ships before UI; route returns 503 until this flips
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Streaming parity v2 — opt in to LangChain's structured
|
||||||
|
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
|
||||||
|
# deltas) and propagate the real ``tool_call_id`` to the SSE layer.
|
||||||
|
# When OFF the ``stream_new_chat`` task falls back to the str-only
|
||||||
|
# text path and the synthetic ``call_<run_id>`` tool-call id (no
|
||||||
|
# ``langchainToolCallId`` propagation). Schema migrations 135/136
|
||||||
|
# ship unconditionally because they're forward-compatible.
|
||||||
|
enable_stream_parity_v2: bool = False
|
||||||
|
|
||||||
# Plugins
|
# Plugins
|
||||||
enable_plugin_loader: bool = False
|
enable_plugin_loader: bool = False
|
||||||
|
|
||||||
|
|
@ -139,6 +149,10 @@ class AgentFeatureFlags:
|
||||||
# Snapshot / revert
|
# Snapshot / revert
|
||||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
|
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
|
||||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
|
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
|
||||||
|
# Streaming parity v2
|
||||||
|
enable_stream_parity_v2=_env_bool(
|
||||||
|
"SURFSENSE_ENABLE_STREAM_PARITY_V2", False
|
||||||
|
),
|
||||||
# Plugins
|
# Plugins
|
||||||
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
||||||
# Observability
|
# Observability
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,14 @@ extra fields needed to implement Postgres-backed virtual filesystem semantics:
|
||||||
|
|
||||||
* ``cwd`` — current working directory (per-thread checkpointed).
|
* ``cwd`` — current working directory (per-thread checkpointed).
|
||||||
* ``staged_dirs`` — pending mkdir requests (cloud only).
|
* ``staged_dirs`` — pending mkdir requests (cloud only).
|
||||||
|
* ``staged_dir_tool_calls`` — sidecar map ``path -> tool_call_id`` for staged dirs.
|
||||||
* ``pending_moves`` — pending move_file requests (cloud only).
|
* ``pending_moves`` — pending move_file requests (cloud only).
|
||||||
|
* ``pending_deletes`` — pending ``rm`` requests (cloud only).
|
||||||
|
* ``pending_dir_deletes`` — pending ``rmdir`` requests (cloud only).
|
||||||
* ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads.
|
* ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads.
|
||||||
* ``dirty_paths`` — paths whose state file content differs from DB.
|
* ``dirty_paths`` — paths whose state file content differs from DB.
|
||||||
|
* ``dirty_path_tool_calls`` — sidecar map ``path -> latest tool_call_id`` for
|
||||||
|
dirty paths; used to bind the per-path snapshot to an action_id.
|
||||||
* ``kb_priority`` — top-K priority hints rendered into a system message.
|
* ``kb_priority`` — top-K priority hints rendered into a system message.
|
||||||
* ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting.
|
* ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting.
|
||||||
* ``kb_anon_doc`` — Redis-loaded anonymous document (if any).
|
* ``kb_anon_doc`` — Redis-loaded anonymous document (if any).
|
||||||
|
|
@ -32,12 +37,31 @@ from app.agents.new_chat.state_reducers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PendingMove(TypedDict):
|
class PendingMove(TypedDict, total=False):
|
||||||
"""A staged move_file operation pending end-of-turn commit."""
|
"""A staged move_file operation pending end-of-turn commit.
|
||||||
|
|
||||||
|
``tool_call_id`` is optional for backward compatibility with checkpoints
|
||||||
|
written before the snapshot/revert pipeline was wired up; new entries
|
||||||
|
always include it so the persistence body can resolve an action_id.
|
||||||
|
"""
|
||||||
|
|
||||||
source: str
|
source: str
|
||||||
dest: str
|
dest: str
|
||||||
overwrite: bool
|
overwrite: bool
|
||||||
|
tool_call_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class PendingDelete(TypedDict, total=False):
    """A staged ``rm`` or ``rmdir`` operation pending end-of-turn commit.

    ``tool_call_id`` is required for new entries (it's the binding key used
    by :class:`KnowledgeBasePersistenceMiddleware` to find the matching
    :class:`AgentActionLog` row and bind the snapshot to it). Marked
    ``total=False`` only to tolerate older checkpoint payloads.
    """

    # Path targeted by the staged ``rm`` / ``rmdir``.
    path: str
    # Id of the tool call that staged this delete. It may be missing on
    # entries restored from older checkpoints, so read it defensively.
    tool_call_id: str
|
||||||
|
|
||||||
|
|
||||||
class KbPriorityEntry(TypedDict, total=False):
|
class KbPriorityEntry(TypedDict, total=False):
|
||||||
|
|
@ -76,9 +100,38 @@ class SurfSenseFilesystemState(FilesystemState):
|
||||||
staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
||||||
"""mkdir paths staged for end-of-turn folder creation (cloud only)."""
|
"""mkdir paths staged for end-of-turn folder creation (cloud only)."""
|
||||||
|
|
||||||
|
staged_dir_tool_calls: NotRequired[
|
||||||
|
Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
|
||||||
|
]
|
||||||
|
"""``path -> tool_call_id`` sidecar for ``staged_dirs``.
|
||||||
|
|
||||||
|
Used by :class:`KnowledgeBasePersistenceMiddleware` to bind the
|
||||||
|
:class:`FolderRevision` snapshot to the originating ``mkdir`` action.
|
||||||
|
Kept separate from ``staged_dirs`` (which stays a unique-string list)
|
||||||
|
to avoid breaking ``_add_unique_reducer`` semantics.
|
||||||
|
"""
|
||||||
|
|
||||||
pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
|
pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
|
||||||
"""move_file ops staged for end-of-turn commit (cloud only)."""
|
"""move_file ops staged for end-of-turn commit (cloud only)."""
|
||||||
|
|
||||||
|
pending_deletes: NotRequired[Annotated[list[PendingDelete], _list_append_reducer]]
|
||||||
|
"""``rm`` ops staged for end-of-turn ``DELETE FROM documents`` (cloud only).
|
||||||
|
|
||||||
|
Each entry is a dict ``{"path": ..., "tool_call_id": ...}``. Per-path
|
||||||
|
uniqueness is enforced inside the commit body, not the reducer (we keep
|
||||||
|
``tool_call_id`` per occurrence so snapshot binding works).
|
||||||
|
"""
|
||||||
|
|
||||||
|
pending_dir_deletes: NotRequired[
|
||||||
|
Annotated[list[PendingDelete], _list_append_reducer]
|
||||||
|
]
|
||||||
|
"""``rmdir`` ops staged for end-of-turn ``DELETE FROM folders`` (cloud only).
|
||||||
|
|
||||||
|
Same shape as :data:`pending_deletes`. Commit body re-verifies the
|
||||||
|
folder is empty (in-DB AND with this turn's pending changes accounted
|
||||||
|
for) before issuing the DELETE.
|
||||||
|
"""
|
||||||
|
|
||||||
doc_id_by_path: NotRequired[
|
doc_id_by_path: NotRequired[
|
||||||
Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
|
Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
|
||||||
]
|
]
|
||||||
|
|
@ -92,6 +145,17 @@ class SurfSenseFilesystemState(FilesystemState):
|
||||||
dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
||||||
"""Paths whose ``state["files"]`` content has been modified this turn."""
|
"""Paths whose ``state["files"]`` content has been modified this turn."""
|
||||||
|
|
||||||
|
dirty_path_tool_calls: NotRequired[
|
||||||
|
Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
|
||||||
|
]
|
||||||
|
"""``path -> latest tool_call_id`` sidecar for ``dirty_paths``.
|
||||||
|
|
||||||
|
The persistence body coalesces multiple writes/edits to the same path
|
||||||
|
into one snapshot per turn. This map captures the most-recent
|
||||||
|
``tool_call_id`` so the resulting :class:`DocumentRevision` is bound
|
||||||
|
to the latest action_id (the one the user is most likely to revert).
|
||||||
|
"""
|
||||||
|
|
||||||
kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
|
kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
|
||||||
"""Top-K priority hints rendered as a system message before the user turn."""
|
"""Top-K priority hints rendered as a system message before the user turn."""
|
||||||
|
|
||||||
|
|
@ -108,6 +172,7 @@ class SurfSenseFilesystemState(FilesystemState):
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"KbAnonDoc",
|
"KbAnonDoc",
|
||||||
"KbPriorityEntry",
|
"KbPriorityEntry",
|
||||||
|
"PendingDelete",
|
||||||
"PendingMove",
|
"PendingMove",
|
||||||
"SurfSenseFilesystemState",
|
"SurfSenseFilesystemState",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ from collections.abc import Awaitable, Callable
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langchain_core.callbacks import adispatch_custom_event
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
|
|
||||||
from app.agents.new_chat.feature_flags import get_flags
|
from app.agents.new_chat.feature_flags import get_flags
|
||||||
|
|
@ -144,11 +145,19 @@ class ActionLogMiddleware(AgentMiddleware):
|
||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tool_call_id = _resolve_tool_call_id(request)
|
||||||
|
chat_turn_id = _resolve_chat_turn_id(request)
|
||||||
|
|
||||||
row = AgentActionLog(
|
row = AgentActionLog(
|
||||||
thread_id=self._thread_id,
|
thread_id=self._thread_id,
|
||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
search_space_id=self._search_space_id,
|
search_space_id=self._search_space_id,
|
||||||
turn_id=_resolve_turn_id(request),
|
# ``turn_id`` is the deprecated alias of ``tool_call_id``
|
||||||
|
# kept for one release for safe rollback. New consumers
|
||||||
|
# should read ``tool_call_id`` directly.
|
||||||
|
turn_id=tool_call_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
chat_turn_id=chat_turn_id,
|
||||||
message_id=_resolve_message_id(request),
|
message_id=_resolve_message_id(request),
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
args=args_payload,
|
args=args_payload,
|
||||||
|
|
@ -160,11 +169,41 @@ class ActionLogMiddleware(AgentMiddleware):
|
||||||
async with shielded_async_session() as session:
|
async with shielded_async_session() as session:
|
||||||
session.add(row)
|
session.add(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
row_id = int(row.id) if row.id is not None else None
|
||||||
|
row_created_at = row.created_at
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"ActionLogMiddleware failed to persist action log row",
|
"ActionLogMiddleware failed to persist action log row",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Surface a side-channel SSE event so the chat tool card can
|
||||||
|
# render a Revert button immediately after the row is durable.
|
||||||
|
# ``stream_new_chat`` translates this into a
|
||||||
|
# ``data-action-log`` SSE event. We DO NOT include the
|
||||||
|
# ``reverse_descriptor`` payload here; only a presence flag.
|
||||||
|
try:
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"action_log",
|
||||||
|
{
|
||||||
|
"id": row_id,
|
||||||
|
"lc_tool_call_id": tool_call_id,
|
||||||
|
"chat_turn_id": chat_turn_id,
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"reversible": bool(reversible),
|
||||||
|
"reverse_descriptor_present": reverse_descriptor is not None,
|
||||||
|
"created_at": row_created_at.isoformat()
|
||||||
|
if row_created_at
|
||||||
|
else None,
|
||||||
|
"error": error_payload is not None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug(
|
||||||
|
"ActionLogMiddleware failed to dispatch action_log event",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
def _render_reverse(
|
def _render_reverse(
|
||||||
self,
|
self,
|
||||||
|
|
@ -254,7 +293,8 @@ def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _resolve_turn_id(request: Any) -> str | None:
|
def _resolve_tool_call_id(request: Any) -> str | None:
|
||||||
|
"""Return the LangChain ``tool_call.id`` for this request, if any."""
|
||||||
try:
|
try:
|
||||||
call = getattr(request, "tool_call", None) or {}
|
call = getattr(request, "tool_call", None) or {}
|
||||||
if isinstance(call, dict):
|
if isinstance(call, dict):
|
||||||
|
|
@ -266,9 +306,40 @@ def _resolve_turn_id(request: Any) -> str | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Deprecated alias kept for one release. Old callers and tests treated
|
||||||
|
# ``turn_id`` as if it carried the LangChain tool_call id; the new column
|
||||||
|
# lives under ``tool_call_id``. Both resolve to the same value today.
|
||||||
|
_resolve_turn_id = _resolve_tool_call_id
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_chat_turn_id(request: Any) -> str | None:
|
||||||
|
"""Return ``configurable.turn_id`` for this request, if accessible.
|
||||||
|
|
||||||
|
``ToolRuntime.config`` is exposed by LangGraph (see
|
||||||
|
``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id
|
||||||
|
lives at ``runtime.config["configurable"]["turn_id"]``.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
runtime = getattr(request, "runtime", None)
|
||||||
|
if runtime is None:
|
||||||
|
return None
|
||||||
|
config = getattr(runtime, "config", None)
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
return None
|
||||||
|
configurable = config.get("configurable")
|
||||||
|
if not isinstance(configurable, dict):
|
||||||
|
return None
|
||||||
|
value = configurable.get("turn_id")
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
return value
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _resolve_message_id(request: Any) -> str | None:
    """Return the message correlator for ``request``.

    Nothing more specific than the tool-call id is available at this layer,
    so the lookup is delegated to ``_resolve_tool_call_id``.
    """
    message_id = _resolve_tool_call_id(request)
    return message_id
|
||||||
|
|
||||||
|
|
||||||
def _resolve_result_id(result: Any) -> str | None:
|
def _resolve_result_id(result: Any) -> str | None:
|
||||||
|
|
|
||||||
|
|
@ -102,6 +102,8 @@ current working directory (`cwd`, default `/documents`).
|
||||||
- cd(path): change the current working directory.
|
- cd(path): change the current working directory.
|
||||||
- pwd(): print the current working directory.
|
- pwd(): print the current working directory.
|
||||||
- move_file(source, dest): move/rename a file under `/documents/`.
|
- move_file(source, dest): move/rename a file under `/documents/`.
|
||||||
|
- rm(path): delete a single file under `/documents/` (no `-r`).
|
||||||
|
- rmdir(path): delete an empty directory under `/documents/`.
|
||||||
- list_tree(path, max_depth, page_size): recursively list files/folders.
|
- list_tree(path, max_depth, page_size): recursively list files/folders.
|
||||||
|
|
||||||
## Persistence Rules
|
## Persistence Rules
|
||||||
|
|
@ -112,8 +114,9 @@ current working directory (`cwd`, default `/documents`).
|
||||||
`/documents/temp_scratch.md`) are **discarded** at end of turn — use this
|
`/documents/temp_scratch.md`) are **discarded** at end of turn — use this
|
||||||
prefix for any scratch/working content you do NOT want saved.
|
prefix for any scratch/working content you do NOT want saved.
|
||||||
- All other paths (outside `/documents/` and not `temp_*`) are rejected.
|
- All other paths (outside `/documents/` and not `temp_*`) are rejected.
|
||||||
- mkdir/move_file are staged this turn and committed at end of turn alongside
|
- mkdir/move_file/rm/rmdir are staged this turn and committed at end of
|
||||||
any new/edited documents.
|
turn alongside any new/edited documents. Snapshot/revert is enabled
|
||||||
|
for every destructive operation when action logging is on.
|
||||||
|
|
||||||
## Reading Documents Efficiently
|
## Reading Documents Efficiently
|
||||||
|
|
||||||
|
|
@ -176,6 +179,8 @@ directory (`cwd`).
|
||||||
- cd(path): change the current working directory.
|
- cd(path): change the current working directory.
|
||||||
- pwd(): print the current working directory.
|
- pwd(): print the current working directory.
|
||||||
- move_file(source, dest): move/rename a file.
|
- move_file(source, dest): move/rename a file.
|
||||||
|
- rm(path): delete a single file from disk (no `-r`). NOT reversible.
|
||||||
|
- rmdir(path): delete an empty directory from disk. NOT reversible.
|
||||||
- list_tree(path, max_depth, page_size): recursively list files/folders.
|
- list_tree(path, max_depth, page_size): recursively list files/folders.
|
||||||
|
|
||||||
## Workflow Tips
|
## Workflow Tips
|
||||||
|
|
@ -184,6 +189,8 @@ directory (`cwd`).
|
||||||
- For large trees, prefer `list_tree` then `grep` then `read_file` over
|
- For large trees, prefer `list_tree` then `grep` then `read_file` over
|
||||||
brute-force directory traversal.
|
brute-force directory traversal.
|
||||||
- Cross-mount moves are not supported.
|
- Cross-mount moves are not supported.
|
||||||
|
- Desktop deletes hit disk immediately and cannot be undone via the
|
||||||
|
agent's revert flow — confirm before calling `rm`/`rmdir`.
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -355,6 +362,42 @@ Notes:
|
||||||
- Parent folders are created as needed.
|
- Parent folders are created as needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_CLOUD_RM_TOOL_DESCRIPTION = """Deletes a single file under `/documents/`.
|
||||||
|
|
||||||
|
Mirrors POSIX `rm path` (no `-r`, no glob expansion). Stages the deletion
|
||||||
|
for end-of-turn commit; the row is removed only after the agent's turn
|
||||||
|
finishes successfully.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- path: absolute or relative file path. Cannot point at a directory — use
|
||||||
|
`rmdir` for empty folders. Cannot target the root or `/documents`.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- The action is reversible via the per-action revert flow when action
|
||||||
|
logging is enabled.
|
||||||
|
- The anonymous uploaded document is read-only and cannot be deleted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_CLOUD_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory under `/documents/`.
|
||||||
|
|
||||||
|
Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
|
||||||
|
deletion (`rm -r`) is intentionally NOT supported — clear contents with
|
||||||
|
`rm` first.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- path: absolute or relative directory path. Cannot target the root,
|
||||||
|
`/documents`, the current cwd, or any ancestor of cwd (use `cd` to
|
||||||
|
move out first).
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Emptiness is evaluated against the post-staged view, so a same-turn
|
||||||
|
`rm /a/x.md` followed by `rmdir /a` is fine.
|
||||||
|
- If the directory was added in this same turn via `mkdir` and never
|
||||||
|
committed, the staged mkdir is dropped instead of issuing a delete.
|
||||||
|
- The action is reversible via the per-action revert flow when action
|
||||||
|
logging is enabled.
|
||||||
|
"""
|
||||||
|
|
||||||
# --- desktop-only ----------------------------------------------------------
|
# --- desktop-only ----------------------------------------------------------
|
||||||
|
|
||||||
_DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
|
_DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
|
||||||
|
|
@ -421,6 +464,28 @@ Notes:
|
||||||
- Parent folders are created as needed.
|
- Parent folders are created as needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_DESKTOP_RM_TOOL_DESCRIPTION = """Deletes a single file from disk.
|
||||||
|
|
||||||
|
Mirrors POSIX `rm path` (no `-r`, no glob expansion). The deletion hits
|
||||||
|
disk immediately. Desktop deletes are NOT reversible via the agent's
|
||||||
|
revert flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- path: absolute mount-prefixed file path. Cannot point at a directory —
|
||||||
|
use `rmdir` for empty folders.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DESKTOP_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory from disk.
|
||||||
|
|
||||||
|
Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
|
||||||
|
deletion is NOT supported. The deletion hits disk immediately and is
|
||||||
|
NOT reversible via the agent's revert flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- path: absolute mount-prefixed directory path. Cannot target the mount
|
||||||
|
root or any directory containing files/subfolders.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
|
def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
|
||||||
"""Pick the active-mode description for every filesystem tool."""
|
"""Pick the active-mode description for every filesystem tool."""
|
||||||
|
|
@ -437,6 +502,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
|
||||||
"mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION,
|
"mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION,
|
||||||
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
|
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
|
||||||
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
|
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
|
||||||
|
"rm": _CLOUD_RM_TOOL_DESCRIPTION,
|
||||||
|
"rmdir": _CLOUD_RMDIR_TOOL_DESCRIPTION,
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
"ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION,
|
"ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION,
|
||||||
|
|
@ -450,6 +517,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
|
||||||
"mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION,
|
"mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION,
|
||||||
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
|
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
|
||||||
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
|
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
|
||||||
|
"rm": _DESKTOP_RM_TOOL_DESCRIPTION,
|
||||||
|
"rmdir": _DESKTOP_RMDIR_TOOL_DESCRIPTION,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -476,6 +545,21 @@ def _basename(path: str) -> str:
|
||||||
return path.rsplit("/", 1)[-1]
|
return path.rsplit("/", 1)[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ancestor_of(candidate: str, target: str) -> bool:
|
||||||
|
"""True iff ``candidate`` is a strict ancestor directory of ``target``.
|
||||||
|
|
||||||
|
``target`` itself is NOT considered an ancestor (use equality for that).
|
||||||
|
Both paths are assumed to be canonicalised, absolute, and free of
|
||||||
|
trailing slashes (except the root ``/``).
|
||||||
|
"""
|
||||||
|
if not candidate.startswith("/") or not target.startswith("/"):
|
||||||
|
return False
|
||||||
|
if candidate == target:
|
||||||
|
return False
|
||||||
|
prefix = candidate.rstrip("/") + "/"
|
||||||
|
return target.startswith(prefix)
|
||||||
|
|
||||||
|
|
||||||
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
"""SurfSense-specific filesystem middleware (cloud + desktop)."""
|
"""SurfSense-specific filesystem middleware (cloud + desktop)."""
|
||||||
|
|
||||||
|
|
@ -519,6 +603,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
self.tools.append(self._create_cd_tool())
|
self.tools.append(self._create_cd_tool())
|
||||||
self.tools.append(self._create_pwd_tool())
|
self.tools.append(self._create_pwd_tool())
|
||||||
self.tools.append(self._create_move_file_tool())
|
self.tools.append(self._create_move_file_tool())
|
||||||
|
self.tools.append(self._create_rm_tool())
|
||||||
|
self.tools.append(self._create_rmdir_tool())
|
||||||
self.tools.append(self._create_list_tree_tool())
|
self.tools.append(self._create_list_tree_tool())
|
||||||
if self._sandbox_available:
|
if self._sandbox_available:
|
||||||
self.tools.append(self._create_execute_code_tool())
|
self.tools.append(self._create_execute_code_tool())
|
||||||
|
|
@ -941,6 +1027,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
}
|
}
|
||||||
if self._is_cloud():
|
if self._is_cloud():
|
||||||
update["dirty_paths"] = [path]
|
update["dirty_paths"] = [path]
|
||||||
|
update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
|
||||||
return Command(update=update)
|
return Command(update=update)
|
||||||
|
|
||||||
def sync_write_file(
|
def sync_write_file(
|
||||||
|
|
@ -1036,6 +1123,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
}
|
}
|
||||||
if self._is_cloud():
|
if self._is_cloud():
|
||||||
update["dirty_paths"] = [path]
|
update["dirty_paths"] = [path]
|
||||||
|
update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
|
||||||
if doc_id_to_attach is not None:
|
if doc_id_to_attach is not None:
|
||||||
update["doc_id_by_path"] = {path: doc_id_to_attach}
|
update["doc_id_by_path"] = {path: doc_id_to_attach}
|
||||||
return Command(update=update)
|
return Command(update=update)
|
||||||
|
|
@ -1103,6 +1191,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
return Command(
|
return Command(
|
||||||
update={
|
update={
|
||||||
"staged_dirs": [validated],
|
"staged_dirs": [validated],
|
||||||
|
"staged_dir_tool_calls": {
|
||||||
|
validated: runtime.tool_call_id,
|
||||||
|
},
|
||||||
"messages": [
|
"messages": [
|
||||||
ToolMessage(
|
ToolMessage(
|
||||||
content=(
|
content=(
|
||||||
|
|
@ -1372,7 +1463,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
files_update: dict[str, Any] = {source: None, dest: source_file_data}
|
files_update: dict[str, Any] = {source: None, dest: source_file_data}
|
||||||
update: dict[str, Any] = {
|
update: dict[str, Any] = {
|
||||||
"files": files_update,
|
"files": files_update,
|
||||||
"pending_moves": [{"source": source, "dest": dest, "overwrite": False}],
|
"pending_moves": [
|
||||||
|
{
|
||||||
|
"source": source,
|
||||||
|
"dest": dest,
|
||||||
|
"overwrite": False,
|
||||||
|
"tool_call_id": runtime.tool_call_id,
|
||||||
|
}
|
||||||
|
],
|
||||||
"messages": [
|
"messages": [
|
||||||
ToolMessage(
|
ToolMessage(
|
||||||
content=(
|
content=(
|
||||||
|
|
@ -1396,6 +1494,323 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
update["dirty_paths"] = new_dirty
|
update["dirty_paths"] = new_dirty
|
||||||
return Command(update=update)
|
return Command(update=update)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ tool: rm
|
||||||
|
|
||||||
|
def _create_rm_tool(self) -> BaseTool:
    """Build the ``rm`` tool, which deletes a single file.

    In cloud mode the tool never deletes anything itself: it records the
    path under ``pending_deletes`` in graph state and returns a ``Command``
    carrying that update. In desktop mode it awaits the backend's
    ``adelete_file`` right away.

    Returns:
        A ``StructuredTool`` named ``rm`` exposing both a sync and an async
        entry point. Each entry point returns a ``Command`` on success or a
        plain string describing why the request was refused.
    """
    # A caller-supplied description wins; the cloud text is the fallback.
    tool_description = (
        self._custom_tool_descriptions.get("rm") or _CLOUD_RM_TOOL_DESCRIPTION
    )

    async def async_rm(
        path: Annotated[
            str,
            "Absolute or relative path to the file to delete.",
        ],
        runtime: ToolRuntime[None, SurfSenseFilesystemState],
    ) -> Command | str:
        # Reject empty / whitespace-only input before resolving anything.
        if not path or not path.strip():
            return "Error: path is required."

        # A ValueError from validate_path is reported as an error string
        # instead of propagating out of the tool.
        target = self._resolve_relative(path, runtime)
        try:
            validated = validate_path(target)
        except ValueError as exc:
            return f"Error: {exc}"

        if self._is_cloud():
            # Only paths strictly below DOCUMENTS_ROOT may be removed.
            if validated in ("/", DOCUMENTS_ROOT):
                return f"Error: refusing to rm '{validated}'."
            if not validated.startswith(DOCUMENTS_ROOT + "/"):
                return (
                    "Error: cloud rm must target a path under /documents/ "
                    f"(got '{validated}')."
                )

            # The anonymous upload recorded in state is never deletable.
            anon = runtime.state.get("kb_anon_doc") or {}
            if isinstance(anon, dict) and str(anon.get("path") or "") == validated:
                return "Error: the anonymous uploaded document is read-only."

            # Refuse if the path looks like a directory.
            staged_dirs = list(runtime.state.get("staged_dirs") or [])
            if validated in staged_dirs:
                return (
                    f"Error: '{validated}' is a directory. Use rmdir for "
                    "empty directories."
                )
            pending_dir_deletes = list(
                runtime.state.get("pending_dir_deletes") or []
            )
            if any(
                isinstance(d, dict) and d.get("path") == validated
                for d in pending_dir_deletes
            ):
                return f"Error: '{validated}' is already queued for rmdir."

            backend = self._get_backend(runtime)
            if isinstance(backend, KBPostgresBackend):
                # Detect "is a directory" via `ls`: if the path lists
                # children we know it's a folder. Otherwise we still
                # need to confirm it's a real file before staging.
                children = await backend.als_info(validated)
                if children:
                    return (
                        f"Error: '{validated}' is a directory. Use rmdir for "
                        "empty directories."
                    )

            # Already queued for delete this turn?
            # NOTE: unlike the other refusals this message has no "Error:" prefix.
            pending_deletes = list(runtime.state.get("pending_deletes") or [])
            if any(
                isinstance(d, dict) and d.get("path") == validated
                for d in pending_deletes
            ):
                return f"'{validated}' is already queued for deletion."

            # Resolve doc_id (best-effort): file in state or DB.
            files_state = runtime.state.get("files") or {}
            doc_id_by_path = runtime.state.get("doc_id_by_path") or {}
            resolved_doc_id: int | None = doc_id_by_path.get(validated)
            # The backend is consulted only when state knows nothing about
            # the path; that lookup doubles as the existence check.
            if (
                validated not in files_state
                and resolved_doc_id is None
                and isinstance(backend, KBPostgresBackend)
            ):
                loaded = await backend._load_file_data(validated)
                if loaded is None:
                    return f"Error: file '{validated}' not found."
                # NOTE(review): the id unpacked here is not read again below.
                _, resolved_doc_id = loaded

            # presumably a None value tells the state reducer to drop the
            # key — confirm against the reducers for "files" and
            # "doc_id_by_path".
            files_update: dict[str, Any] = {validated: None}
            update: dict[str, Any] = {
                "pending_deletes": [
                    {
                        "path": validated,
                        "tool_call_id": runtime.tool_call_id,
                    }
                ],
                "files": files_update,
                "doc_id_by_path": {validated: None},
                "messages": [
                    ToolMessage(
                        content=(
                            f"Staged delete of '{validated}' (will commit at "
                            "end of turn)."
                        ),
                        tool_call_id=runtime.tool_call_id,
                    )
                ],
            }

            # Drop the path from dirty_paths so a same-turn write+rm
            # doesn't recreate the doc at commit time.
            dirty_paths = list(runtime.state.get("dirty_paths") or [])
            if validated in dirty_paths:
                # presumably _CLEAR is a reducer sentinel that resets the
                # list before the surviving entries are re-added — confirm.
                new_dirty: list[Any] = [_CLEAR]
                for entry in dirty_paths:
                    if entry != validated:
                        new_dirty.append(entry)
                update["dirty_paths"] = new_dirty
                update["dirty_path_tool_calls"] = {validated: None}

            return Command(update=update)

        # Desktop mode — hit disk immediately.
        backend = self._get_backend(runtime)
        # Looked up dynamically: a backend without adelete_file yields an
        # error string rather than an AttributeError.
        adelete = getattr(backend, "adelete_file", None)
        if not callable(adelete):
            return "Error: rm is not supported by the active backend."
        res: WriteResult = await adelete(validated)
        if res.error:
            return res.error
        update_desktop: dict[str, Any] = {
            "files": {validated: None},
            "messages": [
                ToolMessage(
                    content=f"Deleted file '{res.path or validated}'",
                    tool_call_id=runtime.tool_call_id,
                )
            ],
        }
        return Command(update=update_desktop)

    def sync_rm(
        path: Annotated[
            str,
            "Absolute or relative path to the file to delete.",
        ],
        runtime: ToolRuntime[None, SurfSenseFilesystemState],
    ) -> Command | str:
        # Synchronous entry point: runs the async implementation to completion.
        return self._run_async_blocking(async_rm(path, runtime))

    return StructuredTool.from_function(
        name="rm",
        description=tool_description,
        func=sync_rm,
        coroutine=async_rm,
    )
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ tool: rmdir
|
||||||
|
|
||||||
|
def _create_rmdir_tool(self) -> BaseTool:
    """Build the ``rmdir`` tool, which deletes an empty directory.

    In cloud mode the tool either un-stages a directory that is still only
    listed in ``staged_dirs`` or records the path under
    ``pending_dir_deletes``; both outcomes are returned as a ``Command``.
    In desktop mode it awaits the backend's ``armdir`` right away.

    Returns:
        A ``StructuredTool`` named ``rmdir`` exposing both a sync and an
        async entry point. Each entry point returns a ``Command`` on success
        or a plain string describing why the request was refused.
    """
    # A caller-supplied description wins; the cloud text is the fallback.
    tool_description = (
        self._custom_tool_descriptions.get("rmdir") or _CLOUD_RMDIR_TOOL_DESCRIPTION
    )

    async def async_rmdir(
        path: Annotated[
            str,
            "Absolute or relative path of the empty directory to delete.",
        ],
        runtime: ToolRuntime[None, SurfSenseFilesystemState],
    ) -> Command | str:
        # Reject empty / whitespace-only input before resolving anything.
        if not path or not path.strip():
            return "Error: path is required."

        # A ValueError from validate_path is reported as an error string
        # instead of propagating out of the tool.
        target = self._resolve_relative(path, runtime)
        try:
            validated = validate_path(target)
        except ValueError as exc:
            return f"Error: {exc}"

        if self._is_cloud():
            # Only paths strictly below DOCUMENTS_ROOT may be removed.
            if validated in ("/", DOCUMENTS_ROOT):
                return f"Error: refusing to rmdir '{validated}'."
            if not validated.startswith(DOCUMENTS_ROOT + "/"):
                return (
                    "Error: cloud rmdir must target a path under /documents/ "
                    f"(got '{validated}')."
                )

            # Never remove the directory the agent is in, nor any of its
            # ancestors.
            cwd = self._current_cwd(runtime)
            if validated == cwd or _is_ancestor_of(validated, cwd):
                return (
                    f"Error: cannot rmdir '{validated}' because the current "
                    "cwd is at or under it. cd out first."
                )

            staged_dirs = list(runtime.state.get("staged_dirs") or [])
            pending_dir_deletes = list(
                runtime.state.get("pending_dir_deletes") or []
            )
            # NOTE: unlike the other refusals this message has no "Error:" prefix.
            if any(
                isinstance(d, dict) and d.get("path") == validated
                for d in pending_dir_deletes
            ):
                return f"'{validated}' is already queued for deletion."

            backend = self._get_backend(runtime)

            # The path must currently exist either in DB folder paths or
            # in staged_dirs. We rely on KBPostgresBackend.als_info (which
            # already accounts for pending deletes/moves) to evaluate
            # both existence and emptiness against the post-staged view.
            exists_in_staged = validated in staged_dirs
            children: list[Any] = []
            if isinstance(backend, KBPostgresBackend):
                children = list(await backend.als_info(validated))

            # Detect "is a file" — if als_info returns no children but
            # the path is actually a file, we should reject. We use
            # _load_file_data to disambiguate file vs missing folder.
            if (
                isinstance(backend, KBPostgresBackend)
                and not children
                and not exists_in_staged
            ):
                loaded = await backend._load_file_data(validated)
                if loaded is not None:
                    return (
                        f"Error: '{validated}' is a file. Use rm to delete files."
                    )
                # Confirm folder exists in DB by checking the parent listing.
                parent = posixpath.dirname(validated) or "/"
                parent_listing = await backend.als_info(parent)
                parent_has_dir = any(
                    info.get("path") == validated and info.get("is_dir")
                    for info in parent_listing
                )
                if not parent_has_dir:
                    return f"Error: directory '{validated}' not found."

            # Emptiness is judged from the listing fetched above.
            if children:
                return (
                    f"Error: directory '{validated}' is not empty. "
                    "Remove contents first."
                )

            # Same-turn mkdir un-stage: drop the staged_dirs entry
            # entirely and skip queuing a DB delete (nothing was ever
            # committed).
            if exists_in_staged:
                rest = [d for d in staged_dirs if d != validated]
                return Command(
                    update={
                        # presumably _CLEAR is a reducer sentinel that resets
                        # the list before `rest` is re-added — confirm.
                        "staged_dirs": [_CLEAR, *rest],
                        "staged_dir_tool_calls": {validated: None},
                        "messages": [
                            ToolMessage(
                                content=(f"Un-staged directory '{validated}'."),
                                tool_call_id=runtime.tool_call_id,
                            )
                        ],
                    }
                )

            # The tool-call id travels with the queued entry.
            return Command(
                update={
                    "pending_dir_deletes": [
                        {
                            "path": validated,
                            "tool_call_id": runtime.tool_call_id,
                        }
                    ],
                    "messages": [
                        ToolMessage(
                            content=(
                                f"Staged rmdir of '{validated}' (will commit "
                                "at end of turn)."
                            ),
                            tool_call_id=runtime.tool_call_id,
                        )
                    ],
                }
            )

        # Desktop mode — hit disk immediately.
        backend = self._get_backend(runtime)
        # Looked up dynamically: a backend without armdir yields an error
        # string rather than an AttributeError.
        armdir = getattr(backend, "armdir", None)
        if not callable(armdir):
            return "Error: rmdir is not supported by the active backend."
        res: WriteResult = await armdir(validated)
        if res.error:
            return res.error
        return Command(
            update={
                "messages": [
                    ToolMessage(
                        content=f"Deleted directory '{res.path or validated}'",
                        tool_call_id=runtime.tool_call_id,
                    )
                ],
            }
        )

    def sync_rmdir(
        path: Annotated[
            str,
            "Absolute or relative path of the empty directory to delete.",
        ],
        runtime: ToolRuntime[None, SurfSenseFilesystemState],
    ) -> Command | str:
        # Synchronous entry point: runs the async implementation to completion.
        return self._run_async_blocking(async_rmdir(path, runtime))

    return StructuredTool.from_function(
        name="rmdir",
        description=tool_description,
        func=sync_rmdir,
        coroutine=async_rmdir,
    )
|
||||||
|
|
||||||
# ------------------------------------------------------------------ tool: list_tree
|
# ------------------------------------------------------------------ tool: list_tree
|
||||||
|
|
||||||
def _create_list_tree_tool(self) -> BaseTool:
|
def _create_list_tree_tool(self) -> BaseTool:
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -115,6 +115,12 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
def _pending_moves(self) -> list[dict[str, Any]]:
|
def _pending_moves(self) -> list[dict[str, Any]]:
|
||||||
return list(self.state.get("pending_moves") or [])
|
return list(self.state.get("pending_moves") or [])
|
||||||
|
|
||||||
|
def _pending_deletes(self) -> list[dict[str, Any]]:
|
||||||
|
return list(self.state.get("pending_deletes") or [])
|
||||||
|
|
||||||
|
def _pending_dir_deletes(self) -> list[dict[str, Any]]:
|
||||||
|
return list(self.state.get("pending_dir_deletes") or [])
|
||||||
|
|
||||||
def _kb_anon_doc(self) -> dict[str, Any] | None:
|
def _kb_anon_doc(self) -> dict[str, Any] | None:
|
||||||
anon = self.state.get("kb_anon_doc")
|
anon = self.state.get("kb_anon_doc")
|
||||||
return anon if isinstance(anon, dict) else None
|
return anon if isinstance(anon, dict) else None
|
||||||
|
|
@ -140,18 +146,28 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
return path
|
return path
|
||||||
return path.rstrip("/") if path != "/" else path
|
return path.rstrip("/") if path != "/" else path
|
||||||
|
|
||||||
def _moved_view_paths(
|
def _pending_filesystem_view(
|
||||||
self,
|
self,
|
||||||
existing: dict[str, dict[str, Any]],
|
existing: dict[str, dict[str, Any]],
|
||||||
) -> tuple[set[str], dict[str, str]]:
|
) -> tuple[set[str], dict[str, str], set[str]]:
|
||||||
"""Apply ``pending_moves`` to a path set and return ``(removed, alias)``.
|
"""Compute removed/aliased/dir-suppressed paths from staged ops.
|
||||||
|
|
||||||
Removed paths should disappear from listings; ``alias[source] = dest``
|
Returns ``(removed, alias, deleted_dirs)`` where:
|
||||||
means a virtual entry should appear at ``dest`` even if no DB row is
|
|
||||||
yet there.
|
* ``removed`` — paths to drop from listings (sources of pending moves
|
||||||
|
AND paths queued for ``rm``).
|
||||||
|
* ``alias`` — ``{source: dest}`` for pending moves; the dest should
|
||||||
|
appear as a virtual entry even when no DB row is at that path yet.
|
||||||
|
* ``deleted_dirs`` — folder paths queued for ``rmdir``; their entire
|
||||||
|
subtree (descendants) is suppressed from listings/glob/grep.
|
||||||
|
|
||||||
|
Entries in ``existing`` (the ``files`` state cache) keyed by a
|
||||||
|
removed path are popped so a same-turn delete-after-write doesn't
|
||||||
|
leave a stale virtual file in listings.
|
||||||
"""
|
"""
|
||||||
removed: set[str] = set()
|
removed: set[str] = set()
|
||||||
alias: dict[str, str] = {}
|
alias: dict[str, str] = {}
|
||||||
|
deleted_dirs: set[str] = set()
|
||||||
for move in self._pending_moves():
|
for move in self._pending_moves():
|
||||||
src = move.get("source")
|
src = move.get("source")
|
||||||
dst = move.get("dest")
|
dst = move.get("dest")
|
||||||
|
|
@ -160,7 +176,23 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
removed.add(src)
|
removed.add(src)
|
||||||
alias[src] = dst
|
alias[src] = dst
|
||||||
existing.pop(src, None)
|
existing.pop(src, None)
|
||||||
return removed, alias
|
for entry in self._pending_deletes():
|
||||||
|
path = entry.get("path") if isinstance(entry, dict) else None
|
||||||
|
if not path:
|
||||||
|
continue
|
||||||
|
removed.add(path)
|
||||||
|
existing.pop(path, None)
|
||||||
|
for entry in self._pending_dir_deletes():
|
||||||
|
path = entry.get("path") if isinstance(entry, dict) else None
|
||||||
|
if not path:
|
||||||
|
continue
|
||||||
|
deleted_dirs.add(path)
|
||||||
|
return removed, alias, deleted_dirs
|
||||||
|
|
||||||
|
@staticmethod
def _is_dir_suppressed(path: str, deleted_dirs: set[str]) -> bool:
    """Report whether ``path`` is covered by any entry of ``deleted_dirs``.

    A path is covered when it equals one of the directories, or when
    ``_is_under`` reports it as lying beneath one of them.
    """
    for doomed in deleted_dirs:
        if path == doomed:
            return True
        if _is_under(path, doomed):
            return True
    return False
|
||||||
|
|
||||||
# ------------------------------------------------------------------ ls/read
|
# ------------------------------------------------------------------ ls/read
|
||||||
|
|
||||||
|
|
@ -189,7 +221,7 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
seen.add(anon_path)
|
seen.add(anon_path)
|
||||||
|
|
||||||
files = self._state_files()
|
files = self._state_files()
|
||||||
moved_removed, moved_alias = self._moved_view_paths(files)
|
moved_removed, moved_alias, deleted_dirs = self._pending_filesystem_view(files)
|
||||||
|
|
||||||
if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
|
if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
|
||||||
try:
|
try:
|
||||||
|
|
@ -203,7 +235,12 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
|
|
||||||
for info in db_infos:
|
for info in db_infos:
|
||||||
p = info.get("path", "")
|
p = info.get("path", "")
|
||||||
if not p or p in seen or p in moved_removed:
|
if (
|
||||||
|
not p
|
||||||
|
or p in seen
|
||||||
|
or p in moved_removed
|
||||||
|
or self._is_dir_suppressed(p, deleted_dirs)
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
infos.append(info)
|
infos.append(info)
|
||||||
seen.add(p)
|
seen.add(p)
|
||||||
|
|
@ -212,6 +249,8 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
if src not in seen:
|
if src not in seen:
|
||||||
if not _is_under(dst, normalized):
|
if not _is_under(dst, normalized):
|
||||||
continue
|
continue
|
||||||
|
if self._is_dir_suppressed(dst, deleted_dirs):
|
||||||
|
continue
|
||||||
rel = (
|
rel = (
|
||||||
dst[len(normalized) :].lstrip("/")
|
dst[len(normalized) :].lstrip("/")
|
||||||
if normalized != "/"
|
if normalized != "/"
|
||||||
|
|
@ -247,6 +286,8 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
continue
|
continue
|
||||||
if not _is_under(staged, normalized):
|
if not _is_under(staged, normalized):
|
||||||
continue
|
continue
|
||||||
|
if self._is_dir_suppressed(staged, deleted_dirs):
|
||||||
|
continue
|
||||||
rel = (
|
rel = (
|
||||||
staged[len(normalized) :].lstrip("/")
|
staged[len(normalized) :].lstrip("/")
|
||||||
if normalized != "/"
|
if normalized != "/"
|
||||||
|
|
@ -265,14 +306,26 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
for sub in sorted(subdir_paths):
|
for sub in sorted(subdir_paths):
|
||||||
if sub in seen:
|
if sub in seen:
|
||||||
continue
|
continue
|
||||||
|
if self._is_dir_suppressed(sub, deleted_dirs):
|
||||||
|
continue
|
||||||
infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at=""))
|
infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at=""))
|
||||||
seen.add(sub)
|
seen.add(sub)
|
||||||
|
|
||||||
for path_key, fd in files.items():
|
for path_key, fd in files.items():
|
||||||
if not isinstance(path_key, str) or path_key in seen:
|
if not isinstance(path_key, str) or path_key in seen:
|
||||||
continue
|
continue
|
||||||
|
# Tombstones (None values) are deletion markers from `rm`. The
|
||||||
|
# deepagents reducer normally pops them, but a stale tombstone
|
||||||
|
# surviving a checkpoint must NOT be reported as a child here —
|
||||||
|
# otherwise rmdir mistakenly sees the deleted file as content.
|
||||||
|
if fd is None:
|
||||||
|
continue
|
||||||
if not _is_under(path_key, normalized) or path_key == normalized:
|
if not _is_under(path_key, normalized) or path_key == normalized:
|
||||||
continue
|
continue
|
||||||
|
if path_key in moved_removed or self._is_dir_suppressed(
|
||||||
|
path_key, deleted_dirs
|
||||||
|
):
|
||||||
|
continue
|
||||||
if normalized == "/":
|
if normalized == "/":
|
||||||
rel = path_key.lstrip("/")
|
rel = path_key.lstrip("/")
|
||||||
else:
|
else:
|
||||||
|
|
@ -550,10 +603,12 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
seen: set[str] = set()
|
seen: set[str] = set()
|
||||||
|
|
||||||
files = self._state_files()
|
files = self._state_files()
|
||||||
moved_removed, _ = self._moved_view_paths(files)
|
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
|
||||||
regex = re.compile(fnmatch.translate(pattern))
|
regex = re.compile(fnmatch.translate(pattern))
|
||||||
for path_key, fd in files.items():
|
for path_key, fd in files.items():
|
||||||
if path_key in moved_removed:
|
if path_key in moved_removed or self._is_dir_suppressed(
|
||||||
|
path_key, deleted_dirs
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if not _is_under(path_key, normalized):
|
if not _is_under(path_key, normalized):
|
||||||
continue
|
continue
|
||||||
|
|
@ -595,7 +650,11 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
folder_id=row.folder_id,
|
folder_id=row.folder_id,
|
||||||
index=index,
|
index=index,
|
||||||
)
|
)
|
||||||
if candidate in seen or candidate in moved_removed:
|
if (
|
||||||
|
candidate in seen
|
||||||
|
or candidate in moved_removed
|
||||||
|
or self._is_dir_suppressed(candidate, deleted_dirs)
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if not _is_under(candidate, normalized):
|
if not _is_under(candidate, normalized):
|
||||||
continue
|
continue
|
||||||
|
|
@ -634,10 +693,12 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
matches: list[GrepMatch] = []
|
matches: list[GrepMatch] = []
|
||||||
|
|
||||||
files = self._state_files()
|
files = self._state_files()
|
||||||
moved_removed, _ = self._moved_view_paths(files)
|
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
|
||||||
glob_re = re.compile(fnmatch.translate(glob)) if glob else None
|
glob_re = re.compile(fnmatch.translate(glob)) if glob else None
|
||||||
for path_key, fd in files.items():
|
for path_key, fd in files.items():
|
||||||
if path_key in moved_removed:
|
if path_key in moved_removed or self._is_dir_suppressed(
|
||||||
|
path_key, deleted_dirs
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if not _is_under(path_key, normalized):
|
if not _is_under(path_key, normalized):
|
||||||
continue
|
continue
|
||||||
|
|
@ -695,7 +756,11 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
)
|
)
|
||||||
for doc_id, chunk_id, content in chunk_buffer:
|
for doc_id, chunk_id, content in chunk_buffer:
|
||||||
candidate = doc_id_to_path.get(doc_id)
|
candidate = doc_id_to_path.get(doc_id)
|
||||||
if not candidate or candidate in moved_removed:
|
if (
|
||||||
|
not candidate
|
||||||
|
or candidate in moved_removed
|
||||||
|
or self._is_dir_suppressed(candidate, deleted_dirs)
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if not _is_under(candidate, normalized):
|
if not _is_under(candidate, normalized):
|
||||||
continue
|
continue
|
||||||
|
|
@ -769,7 +834,7 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
return {"entries": [], "truncated": False}
|
return {"entries": [], "truncated": False}
|
||||||
|
|
||||||
files = self._state_files()
|
files = self._state_files()
|
||||||
moved_removed, _ = self._moved_view_paths(files)
|
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
|
||||||
anon = self._kb_anon_doc()
|
anon = self._kb_anon_doc()
|
||||||
anon_path = str(anon.get("path") or "") if anon else ""
|
anon_path = str(anon.get("path") or "") if anon else ""
|
||||||
|
|
||||||
|
|
@ -795,6 +860,8 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]):
|
for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]):
|
||||||
if not _is_under(fpath, normalized):
|
if not _is_under(fpath, normalized):
|
||||||
continue
|
continue
|
||||||
|
if self._is_dir_suppressed(fpath, deleted_dirs):
|
||||||
|
continue
|
||||||
depth = _depth_of(fpath)
|
depth = _depth_of(fpath)
|
||||||
if max_depth is not None and depth > max_depth:
|
if max_depth is not None and depth > max_depth:
|
||||||
continue
|
continue
|
||||||
|
|
@ -811,6 +878,8 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
for staged in self._staged_dirs():
|
for staged in self._staged_dirs():
|
||||||
if not _is_under(staged, normalized):
|
if not _is_under(staged, normalized):
|
||||||
continue
|
continue
|
||||||
|
if self._is_dir_suppressed(staged, deleted_dirs):
|
||||||
|
continue
|
||||||
depth = _depth_of(staged)
|
depth = _depth_of(staged)
|
||||||
if max_depth is not None and depth > max_depth:
|
if max_depth is not None and depth > max_depth:
|
||||||
continue
|
continue
|
||||||
|
|
@ -835,7 +904,9 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
folder_id=row.folder_id,
|
folder_id=row.folder_id,
|
||||||
index=index,
|
index=index,
|
||||||
)
|
)
|
||||||
if candidate in moved_removed:
|
if candidate in moved_removed or self._is_dir_suppressed(
|
||||||
|
candidate, deleted_dirs
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if not _is_under(candidate, normalized):
|
if not _is_under(candidate, normalized):
|
||||||
continue
|
continue
|
||||||
|
|
@ -875,6 +946,10 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
continue
|
continue
|
||||||
if not _is_under(path_key, normalized):
|
if not _is_under(path_key, normalized):
|
||||||
continue
|
continue
|
||||||
|
if path_key in moved_removed or self._is_dir_suppressed(
|
||||||
|
path_key, deleted_dirs
|
||||||
|
):
|
||||||
|
continue
|
||||||
if any(e["path"] == path_key for e in entries):
|
if any(e["path"] == path_key for e in entries):
|
||||||
continue
|
continue
|
||||||
if not (
|
if not (
|
||||||
|
|
|
||||||
|
|
@ -201,6 +201,12 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
)
|
)
|
||||||
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
|
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
|
||||||
|
|
||||||
|
# Pre-compute which folders have at least one descendant (folder or doc).
|
||||||
|
# A folder is "empty" iff no path in `all_paths` is strictly under it.
|
||||||
|
# Used to emit an explicit "(empty)" marker so the LLM doesn't have to
|
||||||
|
# infer emptiness from indentation alone.
|
||||||
|
non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)
|
||||||
|
|
||||||
lines: list[str] = []
|
lines: list[str] = []
|
||||||
for path in all_paths:
|
for path in all_paths:
|
||||||
depth = (
|
depth = (
|
||||||
|
|
@ -214,7 +220,10 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
|
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
|
||||||
)
|
)
|
||||||
if is_dir:
|
if is_dir:
|
||||||
lines.append(f"{indent}{display}/")
|
if path != DOCUMENTS_ROOT and path not in non_empty_folders:
|
||||||
|
lines.append(f"{indent}{display}/ (empty)")
|
||||||
|
else:
|
||||||
|
lines.append(f"{indent}{display}/")
|
||||||
else:
|
else:
|
||||||
lines.append(f"{indent}{display}")
|
lines.append(f"{indent}{display}")
|
||||||
if len(lines) >= self.max_entries:
|
if len(lines) >= self.max_entries:
|
||||||
|
|
@ -235,6 +244,35 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
||||||
return self._format_root_summary(folder_paths, doc_paths)
|
return self._format_root_summary(folder_paths, doc_paths)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_non_empty_folders(
|
||||||
|
folder_paths: list[str], doc_paths: list[str]
|
||||||
|
) -> set[str]:
|
||||||
|
"""Return the set of folder paths that contain at least one descendant.
|
||||||
|
|
||||||
|
A folder is "non-empty" if any document path or any other folder path
|
||||||
|
is strictly under it. Documents propagate emptiness up to every
|
||||||
|
ancestor folder, while a sub-folder only marks its direct ancestors
|
||||||
|
non-empty (so a chain of empty folders all read ``(empty)``).
|
||||||
|
"""
|
||||||
|
non_empty: set[str] = set()
|
||||||
|
folder_set = set(folder_paths)
|
||||||
|
|
||||||
|
for doc_path in doc_paths:
|
||||||
|
parent = doc_path.rsplit("/", 1)[0]
|
||||||
|
while parent and parent != DOCUMENTS_ROOT:
|
||||||
|
if parent in folder_set:
|
||||||
|
non_empty.add(parent)
|
||||||
|
parent = parent.rsplit("/", 1)[0]
|
||||||
|
|
||||||
|
for child in folder_paths:
|
||||||
|
parent = child.rsplit("/", 1)[0]
|
||||||
|
while parent and parent != DOCUMENTS_ROOT and parent in folder_set:
|
||||||
|
non_empty.add(parent)
|
||||||
|
parent = parent.rsplit("/", 1)[0]
|
||||||
|
|
||||||
|
return non_empty
|
||||||
|
|
||||||
def _format_root_summary(
|
def _format_root_summary(
|
||||||
self, folder_paths: list[str], doc_paths: list[str]
|
self, folder_paths: list[str], doc_paths: list[str]
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
|
||||||
|
|
@ -360,6 +360,74 @@ class LocalFolderBackend:
|
||||||
self.move, source_path, destination_path, overwrite
|
self.move, source_path, destination_path, overwrite
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def delete_file(self, file_path: str) -> WriteResult:
    """Remove a single file addressed by a virtual path.

    An error result is returned when the path cannot be resolved, does
    not exist, or names a directory. No recursion and no glob expansion
    is performed.

    Args:
        file_path: Virtual path of the file to remove.

    Returns:
        ``WriteResult`` with ``path=file_path`` on success, otherwise one
        whose ``error`` describes why nothing was deleted.
    """
    try:
        target = self._resolve_virtual(file_path)
    except ValueError:
        return WriteResult(error=f"Error: Invalid path '{file_path}'")

    failure: str | None = None
    with self._lock_for(file_path):
        if not target.exists():
            failure = f"Error: File '{file_path}' not found"
        elif target.is_dir():
            failure = (
                f"Error: '{file_path}' is a directory. "
                "Use rmdir for empty directories."
            )
        else:
            try:
                os.unlink(target)
            except OSError as exc:
                failure = f"Error: failed to delete '{file_path}': {exc}"

    if failure is not None:
        return WriteResult(error=failure)
    return WriteResult(path=file_path, files_update=None)
|
||||||
|
|
||||||
|
async def adelete_file(self, file_path: str) -> WriteResult:
    """Async wrapper: run ``delete_file`` through ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.delete_file, file_path)
    return outcome
|
||||||
|
|
||||||
|
def rmdir(self, dir_path: str) -> WriteResult:
    """Remove a single empty directory addressed by a virtual path.

    An error result is returned when the path cannot be resolved, does
    not exist, is not a directory, or still has entries. ``os.rmdir``
    would reject a non-empty directory by itself; the explicit emptiness
    probe exists to produce a clearer message.

    NOTE(review): no explicit root guard is visible in this method;
    refusing the root presumably relies on ``_resolve_virtual`` -- confirm.

    Args:
        dir_path: Virtual path of the directory to remove.

    Returns:
        ``WriteResult`` with ``path=dir_path`` on success, otherwise one
        whose ``error`` describes why nothing was removed.
    """
    try:
        target = self._resolve_virtual(dir_path)
    except ValueError:
        return WriteResult(error=f"Error: Invalid path '{dir_path}'")

    failure: str | None = None
    with self._lock_for(dir_path):
        if not target.exists():
            failure = f"Error: Directory '{dir_path}' not found"
        elif not target.is_dir():
            failure = f"Error: '{dir_path}' is not a directory"
        elif next(target.iterdir(), None) is not None:
            # At least one entry exists, so the directory is not empty.
            failure = (
                f"Error: directory '{dir_path}' is not empty. "
                "Remove its contents first."
            )
        else:
            try:
                os.rmdir(target)
            except OSError as exc:
                failure = f"Error: failed to rmdir '{dir_path}': {exc}"

    if failure is not None:
        return WriteResult(error=failure)
    return WriteResult(path=dir_path, files_update=None)
|
||||||
|
|
||||||
|
async def armdir(self, dir_path: str) -> WriteResult:
    """Async wrapper: run ``rmdir`` through ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.rmdir, dir_path)
    return outcome
|
||||||
|
|
||||||
def edit(
|
def edit(
|
||||||
self,
|
self,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
|
|
|
||||||
|
|
@ -285,6 +285,34 @@ class MultiRootLocalFolderBackend:
|
||||||
overwrite,
|
overwrite,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def delete_file(self, file_path: str) -> WriteResult:
    """Delete a file by delegating to the backend registered for its mount.

    ``file_path`` is split into a mount and a mount-local path. A
    ``ValueError`` from that split is turned into an error result. When
    the delegated result carries a ``path``, it is rewritten with the
    mount prefix before being returned.
    """
    try:
        mount, local_path = self._split_mount_path(file_path)
    except ValueError as exc:
        return WriteResult(error=f"Error: {exc}")

    backend = self._mount_to_backend[mount]
    outcome = backend.delete_file(local_path)
    if not outcome.path:
        return outcome
    outcome.path = self._prefix_mount_path(mount, outcome.path)
    return outcome
|
||||||
|
|
||||||
|
async def adelete_file(self, file_path: str) -> WriteResult:
    """Async wrapper: run ``delete_file`` through ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.delete_file, file_path)
    return outcome
|
||||||
|
|
||||||
|
def rmdir(self, dir_path: str) -> WriteResult:
    """Remove a directory by delegating to the backend registered for its mount.

    ``dir_path`` is split into a mount and a mount-local path. A
    ``ValueError`` from that split is turned into an error result, and a
    mount-local path of ``"/"`` (the mount root itself) is refused before
    delegating. When the delegated result carries a ``path``, it is
    rewritten with the mount prefix before being returned.
    """
    try:
        mount, local_path = self._split_mount_path(dir_path)
    except ValueError as exc:
        return WriteResult(error=f"Error: {exc}")

    if local_path == "/":
        return WriteResult(error=f"Error: cannot rmdir mount root '{dir_path}'")

    backend = self._mount_to_backend[mount]
    outcome = backend.rmdir(local_path)
    if not outcome.path:
        return outcome
    outcome.path = self._prefix_mount_path(mount, outcome.path)
    return outcome
|
||||||
|
|
||||||
|
async def armdir(self, dir_path: str) -> WriteResult:
    """Async wrapper: run ``rmdir`` through ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.rmdir, dir_path)
    return outcome
|
||||||
|
|
||||||
def edit(
|
def edit(
|
||||||
self,
|
self,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
|
|
|
||||||
|
|
@ -181,9 +181,13 @@ def _initial_filesystem_state() -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"cwd": "/documents",
|
"cwd": "/documents",
|
||||||
"staged_dirs": [],
|
"staged_dirs": [],
|
||||||
|
"staged_dir_tool_calls": {},
|
||||||
"pending_moves": [],
|
"pending_moves": [],
|
||||||
|
"pending_deletes": [],
|
||||||
|
"pending_dir_deletes": [],
|
||||||
"doc_id_by_path": {},
|
"doc_id_by_path": {},
|
||||||
"dirty_paths": [],
|
"dirty_paths": [],
|
||||||
|
"dirty_path_tool_calls": {},
|
||||||
"kb_priority": [],
|
"kb_priority": [],
|
||||||
"kb_matched_chunk_ids": {},
|
"kb_matched_chunk_ids": {},
|
||||||
"kb_anon_doc": None,
|
"kb_anon_doc": None,
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,8 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = (
|
||||||
"write_file",
|
"write_file",
|
||||||
"move_file",
|
"move_file",
|
||||||
"mkdir",
|
"mkdir",
|
||||||
|
"rm",
|
||||||
|
"rmdir",
|
||||||
"update_memory",
|
"update_memory",
|
||||||
"update_memory_team",
|
"update_memory_team",
|
||||||
"update_memory_private",
|
"update_memory_private",
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,35 @@ from langgraph.types import interrupt
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Tools that mirror the safety profile of ``write_file`` against the
|
||||||
|
# SurfSense KB: each call creates ONE artifact in the user's own workspace
|
||||||
|
# with no external visibility (drafts aren't sent; new files aren't shared
|
||||||
|
# unless the user shares them later). These are auto-approved by default
|
||||||
|
# so the agent can compose drafts and seed scratch files without a popup
|
||||||
|
# on every call.
|
||||||
|
#
|
||||||
|
# Members of this set still call ``request_approval`` exactly as before;
|
||||||
|
# the function returns immediately with ``decision_type="auto_approved"``
|
||||||
|
# and the original params untouched. This preserves the call-site shape
|
||||||
|
# (logging, metadata fetching, account fallbacks) so the only behavior
|
||||||
|
# change is "no interrupt fires".
|
||||||
|
#
|
||||||
|
# To re-enable prompting, the future per-search-space rules table
|
||||||
|
# (``agent_permission_rules``) takes precedence — see the ``# (future)``
|
||||||
|
# layer-3 comment in :mod:`app.agents.new_chat.chat_deepagent`.
|
||||||
|
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
    (
        # Mail drafts
        "create_gmail_draft",
        "update_gmail_draft",
        # New pages
        "create_notion_page",
        "create_confluence_page",
        # New files
        "create_google_drive_file",
        "create_dropbox_file",
        "create_onedrive_file",
    )
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
class HITLResult:
|
class HITLResult:
|
||||||
"""Outcome of a human-in-the-loop approval request."""
|
"""Outcome of a human-in-the-loop approval request."""
|
||||||
|
|
@ -119,6 +148,19 @@ def request_approval(
|
||||||
logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
|
logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
|
||||||
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
|
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
|
||||||
|
|
||||||
|
if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
|
||||||
|
# Default policy: low-stakes creation tools (drafts + new-file
|
||||||
|
# creates) skip HITL because they're as recoverable as a local
|
||||||
|
# ``write_file`` against the SurfSense KB. The user can still
|
||||||
|
# delete the artifact in <30s if it's wrong.
|
||||||
|
logger.info(
|
||||||
|
"Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
|
||||||
|
tool_name,
|
||||||
|
)
|
||||||
|
return HITLResult(
|
||||||
|
rejected=False, decision_type="auto_approved", params=dict(params)
|
||||||
|
)
|
||||||
|
|
||||||
approval = interrupt(
|
approval = interrupt(
|
||||||
{
|
{
|
||||||
"type": action_type,
|
"type": action_type,
|
||||||
|
|
|
||||||
|
|
@ -689,6 +689,12 @@ class NewChatMessage(BaseModel, TimestampMixin):
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Per-turn correlation id sourced from ``configurable.turn_id`` at
|
||||||
|
# streaming time (``f"{chat_id}:{ms}"``). Nullable because legacy rows
|
||||||
|
# predate the column. Used by C1's edit-from-arbitrary-position to map
|
||||||
|
# a message back to the LangGraph checkpoint that produced its turn.
|
||||||
|
turn_id = Column(String(64), nullable=True, index=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
thread = relationship("NewChatThread", back_populates="messages")
|
thread = relationship("NewChatThread", back_populates="messages")
|
||||||
author = relationship("User")
|
author = relationship("User")
|
||||||
|
|
@ -2292,7 +2298,13 @@ class AgentActionLog(BaseModel):
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
|
# ``turn_id`` historically held the LangChain ``tool_call.id``. It has
|
||||||
|
# been renamed to ``tool_call_id`` (with a parallel column kept for one
|
||||||
|
# release for back-compat). The real chat-turn id lives in
|
||||||
|
# ``chat_turn_id`` and is sourced from ``configurable.turn_id``.
|
||||||
turn_id = Column(String(64), nullable=True, index=True)
|
turn_id = Column(String(64), nullable=True, index=True)
|
||||||
|
tool_call_id = Column(String(64), nullable=True, index=True)
|
||||||
|
chat_turn_id = Column(String(64), nullable=True, index=True)
|
||||||
message_id = Column(String(128), nullable=True, index=True)
|
message_id = Column(String(128), nullable=True, index=True)
|
||||||
tool_name = Column(String(255), nullable=False, index=True)
|
tool_name = Column(String(255), nullable=False, index=True)
|
||||||
args = Column(JSONB, nullable=True)
|
args = Column(JSONB, nullable=True)
|
||||||
|
|
@ -2318,6 +2330,16 @@ class AgentActionLog(BaseModel):
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_agent_action_log_thread_created", "thread_id", "created_at"),
|
Index("ix_agent_action_log_thread_created", "thread_id", "created_at"),
|
||||||
|
# Partial unique index enforces "at most one revert per
|
||||||
|
# original action". Created in migration 137 with
|
||||||
|
# ``WHERE reverse_of IS NOT NULL`` so non-revert rows
|
||||||
|
# (the vast majority) are unaffected and NULLs don't collide.
|
||||||
|
Index(
|
||||||
|
"ux_agent_action_log_reverse_of",
|
||||||
|
"reverse_of",
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=text("reverse_of IS NOT NULL"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2332,10 +2354,13 @@ class DocumentRevision(BaseModel):
|
||||||
|
|
||||||
__tablename__ = "document_revisions"
|
__tablename__ = "document_revisions"
|
||||||
|
|
||||||
|
# ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
|
||||||
|
# hard-delete it describes — without that, ``rm`` would wipe the row
|
||||||
|
# we'd need to undo it. See migration ``134_relax_revision_fks``.
|
||||||
document_id = Column(
|
document_id = Column(
|
||||||
Integer,
|
Integer,
|
||||||
ForeignKey("documents.id", ondelete="CASCADE"),
|
ForeignKey("documents.id", ondelete="SET NULL"),
|
||||||
nullable=False,
|
nullable=True,
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
search_space_id = Column(
|
search_space_id = Column(
|
||||||
|
|
@ -2370,10 +2395,13 @@ class FolderRevision(BaseModel):
|
||||||
|
|
||||||
__tablename__ = "folder_revisions"
|
__tablename__ = "folder_revisions"
|
||||||
|
|
||||||
|
# ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
|
||||||
|
# hard-delete it describes — without that, ``rmdir`` would wipe the
|
||||||
|
# row we'd need to undo it. See migration ``134_relax_revision_fks``.
|
||||||
folder_id = Column(
|
folder_id = Column(
|
||||||
Integer,
|
Integer,
|
||||||
ForeignKey("folders.id", ondelete="CASCADE"),
|
ForeignKey("folders.id", ondelete="SET NULL"),
|
||||||
nullable=False,
|
nullable=True,
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
search_space_id = Column(
|
search_space_id = Column(
|
||||||
|
|
|
||||||
|
|
@ -65,6 +65,13 @@ class AgentActionRead(BaseModel):
|
||||||
reverse_of: int | None
|
reverse_of: int | None
|
||||||
reverted_by_action_id: int | None
|
reverted_by_action_id: int | None
|
||||||
is_revert_action: bool
|
is_revert_action: bool
|
||||||
|
# Correlation ids added in migration 135. ``tool_call_id`` is the
|
||||||
|
# LangChain tool-call id (joinable to ``data-action-log`` SSE events
|
||||||
|
# via ``langchainToolCallId``). ``chat_turn_id`` is the per-turn id
|
||||||
|
# from ``configurable.turn_id`` (used by the
|
||||||
|
# ``revert-turn/{chat_turn_id}`` endpoint).
|
||||||
|
tool_call_id: str | None = None
|
||||||
|
chat_turn_id: str | None = None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -172,6 +179,8 @@ async def list_thread_actions(
|
||||||
reverse_of=row.reverse_of,
|
reverse_of=row.reverse_of,
|
||||||
reverted_by_action_id=revert_map.get(row.id),
|
reverted_by_action_id=revert_map.get(row.id),
|
||||||
is_revert_action=row.reverse_of is not None,
|
is_revert_action=row.reverse_of is not None,
|
||||||
|
tool_call_id=row.tool_call_id,
|
||||||
|
chat_turn_id=row.chat_turn_id,
|
||||||
created_at=row.created_at,
|
created_at=row.created_at,
|
||||||
)
|
)
|
||||||
for row in rows
|
for row in rows
|
||||||
|
|
|
||||||
|
|
@ -11,14 +11,25 @@ flag flips. Once enabled, the route runs:
|
||||||
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
|
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
|
||||||
5. Idempotent on retries: if the same action is reverted twice the second
|
5. Idempotent on retries: if the same action is reverted twice the second
|
||||||
call returns 409 ``"already reverted"``.
|
call returns 409 ``"already reverted"``.
|
||||||
|
|
||||||
|
This module also hosts the per-turn batch endpoint
|
||||||
|
``POST /api/threads/{thread_id}/revert-turn/{chat_turn_id}``. It
|
||||||
|
walks every reversible action emitted during a chat turn in reverse
|
||||||
|
``created_at`` order and reverts each independently. Partial success is the
|
||||||
|
common case — the response always contains a per-action result list and a
|
||||||
|
``status`` of ``"ok"`` or ``"partial"``; we never collapse the batch into a
|
||||||
|
whole-batch 4xx.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.feature_flags import get_flags
|
from app.agents.new_chat.feature_flags import get_flags
|
||||||
|
|
@ -97,6 +108,16 @@ async def revert_agent_action(
|
||||||
action=action,
|
action=action,
|
||||||
requester_user_id=str(user.id) if user is not None else None,
|
requester_user_id=str(user.id) if user is not None else None,
|
||||||
)
|
)
|
||||||
|
except IntegrityError:
|
||||||
|
# Partial unique index ``ux_agent_action_log_reverse_of`` caught
|
||||||
|
# a concurrent revert. Translate to the existing 409 "already
|
||||||
|
# reverted" contract so racing clients see consistent
|
||||||
|
# behaviour with the pre-flight TOCTOU check above.
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail="This action has already been reverted.",
|
||||||
|
) from None
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.exception("Revert dispatch raised for action_id=%s", action_id)
|
logger.exception("Revert dispatch raised for action_id=%s", action_id)
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
|
|
@ -105,7 +126,16 @@ async def revert_agent_action(
|
||||||
) from err
|
) from err
|
||||||
|
|
||||||
if outcome.status == "ok":
|
if outcome.status == "ok":
|
||||||
await session.commit()
|
try:
|
||||||
|
await session.commit()
|
||||||
|
except IntegrityError:
|
||||||
|
# Race lost on commit (constraint enforced at flush in some
|
||||||
|
# configs but at commit in others — defensive).
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail="This action has already been reverted.",
|
||||||
|
) from None
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"message": outcome.message,
|
"message": outcome.message,
|
||||||
|
|
@ -122,3 +152,357 @@ async def revert_agent_action(
|
||||||
raise HTTPException(status_code=501, detail=outcome.message)
|
raise HTTPException(status_code=501, detail=outcome.message)
|
||||||
# not_reversible
|
# not_reversible
|
||||||
raise HTTPException(status_code=409, detail=outcome.message)
|
raise HTTPException(status_code=409, detail=outcome.message)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-turn revert batch endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# Closed set of per-action outcome labels used in the revert-turn response.
PerActionStatus = Literal[
    "reverted", "already_reverted", "not_reversible",
    "permission_denied", "failed", "skipped",
]
|
||||||
|
|
||||||
|
|
||||||
|
class RevertTurnActionResult(BaseModel):
    """Outcome of one action within a ``revert-turn`` batch response."""

    # Identity of the action this entry reports on.
    action_id: int
    tool_name: str

    # One of the ``PerActionStatus`` labels.
    status: PerActionStatus

    # Optional details; all default to ``None``.
    message: str | None = None
    # presumably the id of the newly written revert row -- confirm against the endpoint
    new_action_id: int | None = None
    error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RevertTurnResponse(BaseModel):
    """Body returned by ``POST /threads/{id}/revert-turn/{chat_turn_id}``.

    ``status`` is ``"partial"`` as soon as one row ended up ``failed``,
    ``not_reversible`` or ``permission_denied``; otherwise it is ``"ok"``.
    A turn without any rows therefore yields ``"ok"`` and an empty
    ``results`` list, which callers should treat as a no-op.

    Counter invariant::

        total == reverted + already_reverted + not_reversible
                 + permission_denied + failed + skipped

    The frontend summary ("X of Y reverted, Z could not be undone")
    depends on this invariant so that ``permission_denied`` and
    ``skipped`` rows are never silently dropped.
    """

    status: Literal["ok", "partial"]
    chat_turn_id: str

    # Number of action rows examined for the turn.
    total: int

    # One counter per ``PerActionStatus`` value.
    reverted: int
    already_reverted: int
    not_reversible: int
    permission_denied: int = 0
    failed: int = 0
    skipped: int = 0

    # Per-row details, in the order the rows were processed.
    results: list[RevertTurnActionResult]
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_outcome(outcome: RevertOutcome) -> PerActionStatus:
    """Translate a ``RevertOutcome`` status into a per-row batch status.

    ``"ok"`` maps to ``"reverted"`` and ``"permission_denied"`` passes
    through unchanged. Every other status (``not_found``,
    ``tool_unavailable``, ``reverse_not_implemented``, ``not_reversible``)
    collapses into ``"not_reversible"``: the UX is the same ("this row
    cannot be undone") and only ``outcome.message`` tells them apart.
    """
    status = outcome.status
    if status == "permission_denied":
        return "permission_denied"
    return "reverted" if status == "ok" else "not_reversible"
|
||||||
|
|
||||||
|
|
||||||
|
async def _was_already_reverted(session: AsyncSession, *, action_id: int) -> int | None:
    """Look up the id of a row that reverts ``action_id``, if one exists.

    Single-action variant of the idempotency probe. It is used after an
    ``IntegrityError``, when we already know a concurrent request won the
    race for this specific action.

    Returns:
        The id of the first ``AgentActionLog`` row whose ``reverse_of``
        equals ``action_id``, or ``None`` when there is no such row.
    """
    lookup = await session.execute(
        select(AgentActionLog.id).where(AgentActionLog.reverse_of == action_id)
    )
    return lookup.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
async def _was_already_reverted_batch(
    session: AsyncSession, *, action_ids: list[int]
) -> dict[int, int]:
    """Find existing revert rows for many actions with one query.

    Issues a single ``SELECT id, reverse_of ... WHERE reverse_of IN (:ids)``
    instead of one ``SELECT`` per action, so the revert-turn loop can do
    a dictionary lookup per row rather than a database round-trip.

    Args:
        session: Session used to run the probe.
        action_ids: Ids of the original (non-revert) actions to check.

    Returns:
        A ``{original_action_id: revert_action_id}`` mapping. An id that
        is absent from the mapping has not been reverted yet. An empty
        ``action_ids`` short-circuits to ``{}`` without touching the
        database.
    """
    if not action_ids:
        return {}

    probe = select(AgentActionLog.id, AgentActionLog.reverse_of).where(
        AgentActionLog.reverse_of.in_(action_ids)
    )
    fetched = await session.execute(probe)

    revert_by_original: dict[int, int] = {}
    for revert_id, original_id in fetched.all():
        if original_id is not None:
            revert_by_original[original_id] = revert_id
    return revert_by_original
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/threads/{thread_id}/revert-turn/{chat_turn_id}",
    response_model=RevertTurnResponse,
)
async def revert_agent_turn(
    thread_id: int,
    chat_turn_id: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
) -> RevertTurnResponse:
    """Undo every reversible action recorded for ``chat_turn_id``.

    The turn's ``AgentActionLog`` rows are processed newest-first
    (``created_at`` descending, ``id`` descending as tiebreaker) so that
    dependent mutations, such as a ``write_file`` inside a freshly
    ``mkdir``-ed folder, unwind in the right order. Every revert runs
    inside its own SAVEPOINT, so one failing row cannot poison the rest.

    Partial success is deliberate and is answered with HTTP 200; callers
    inspect ``results[*].status`` to find rows that need attention.

    Raises:
        HTTPException: 503 when the revert route is disabled by feature
            flags, 404 when the thread does not exist, 500 when the
            final commit fails.
    """
    flags = get_flags()
    if flags.disable_new_agent_stack or not flags.enable_revert_route:
        raise HTTPException(
            status_code=503,
            detail=(
                "Revert is not available on this deployment yet. The route "
                "ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to "
                "enable it."
            ),
        )

    if await load_thread(session, thread_id=thread_id) is None:
        raise HTTPException(status_code=404, detail="Thread not found.")

    # Newest mutation first; ``id`` breaks ties between rows written in
    # the same instant so the order is deterministic.
    turn_query = (
        select(AgentActionLog)
        .where(
            AgentActionLog.thread_id == thread_id,
            AgentActionLog.chat_turn_id == chat_turn_id,
        )
        .order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc())
    )
    rows = (await session.execute(turn_query)).scalars().all()

    requester_user_id = None if user is None else str(user.id)
    results: list[RevertTurnActionResult] = []
    # One counter per ``PerActionStatus`` so that ``total == sum(counts)``
    # always holds for the response.
    counts: dict[str, int] = {
        "reverted": 0,
        "already_reverted": 0,
        "not_reversible": 0,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }

    def record(action, bucket: str, status: str, **fields) -> None:
        # Bump the counter, then append the matching per-row result.
        counts[bucket] += 1
        results.append(
            RevertTurnActionResult(
                action_id=action.id,
                tool_name=action.tool_name,
                status=status,
                **fields,
            )
        )

    # One batched idempotency probe for all original (non-revert) rows.
    eligible_ids = [row.id for row in rows if row.reverse_of is None]
    already_reverted_map = await _was_already_reverted_batch(
        session, action_ids=eligible_ids
    )

    for action in rows:
        # Rows that are themselves reverts are never reverted again
        # inside a batch.
        if action.reverse_of is not None:
            record(
                action,
                "skipped",
                "skipped",
                message="Row is itself a revert action; skipped.",
            )
            continue

        # Idempotency: report the existing revert instead of failing.
        prior_revert_id = already_reverted_map.get(action.id)
        if prior_revert_id is not None:
            record(
                action,
                "already_reverted",
                "already_reverted",
                new_action_id=prior_revert_id,
            )
            continue

        allowed = can_revert(
            requester_user_id=requester_user_id,
            action=action,
            is_admin=False,
        )
        if not allowed:
            record(
                action,
                "permission_denied",
                "permission_denied",
                message="You are not allowed to revert this action.",
            )
            continue

        try:
            # Per-row SAVEPOINT: raising inside the block rolls back
            # only this row's writes.
            async with session.begin_nested():
                outcome = await revert_action(
                    session,
                    action=action,
                    requester_user_id=requester_user_id,
                )
                if outcome.status != "ok":
                    raise _OutcomeRollbackError(outcome)
        except _OutcomeRollbackError as rollback:
            outcome = rollback.outcome
            classified = _classify_outcome(outcome)
            bucket = (
                "permission_denied"
                if classified == "permission_denied"
                else "not_reversible"
            )
            record(action, bucket, classified, message=outcome.message)
        except IntegrityError:
            # A concurrent revert slipped in after the batched pre-flight
            # probe. Look up the winner so the client still receives its
            # ``new_action_id``.
            winner_id = await _was_already_reverted(session, action_id=action.id)
            record(
                action,
                "already_reverted",
                "already_reverted",
                new_action_id=winner_id,
            )
        except Exception as err:  # pragma: no cover — defensive, logged
            logger.exception(
                "Unexpected revert failure inside batch for action_id=%s",
                action.id,
            )
            record(
                action,
                "failed",
                "failed",
                error=str(err) or err.__class__.__name__,
            )
        else:
            record(
                action,
                "reverted",
                "reverted",
                message=outcome.message,
                new_action_id=outcome.new_action_id,
            )

    # Single commit for the whole batch: successful SAVEPOINTs are
    # already released, failed ones were rolled back individually.
    try:
        await session.commit()
    except Exception as err:  # pragma: no cover — defensive
        logger.exception(
            "Final commit for revert-turn failed (thread=%s turn=%s)",
            thread_id,
            chat_turn_id,
        )
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail="Internal error while finalising revert-turn batch.",
        ) from err

    has_partial = any(
        counts[key] > 0 for key in ("failed", "not_reversible", "permission_denied")
    )
    overall_status: Literal["ok", "partial"] = "partial" if has_partial else "ok"

    # ``counts`` keys match the response's counter fields one-to-one.
    return RevertTurnResponse(
        status=overall_status,
        chat_turn_id=chat_turn_id,
        total=len(rows),
        results=results,
        **counts,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class _OutcomeRollbackError(Exception):
    """Carrier exception that forces a SAVEPOINT rollback for a non-OK outcome.

    The batch loops raise this from inside ``session.begin_nested()`` when
    ``revert_action`` returns a status other than ``"ok"``. Leaving the
    block through an exception discards whatever was written for that
    row, and the handler outside the block reads the outcome back from
    ``.outcome``.
    """

    def __init__(self, outcome: RevertOutcome) -> None:
        super().__init__(outcome.message)
        # Kept so the ``except`` handler can classify and report the row.
        self.outcome = outcome
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["router"]
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
|
@ -136,6 +137,260 @@ def _resolve_filesystem_selection(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_pre_turn_checkpoint_id(
    checkpoint_tuples: list,
    *,
    turn_id: str,
) -> str | None:
    """Return the id of the checkpoint written just before ``turn_id`` began.

    ``checkpoint_tuples`` is expected newest-first (the order produced by
    ``checkpointer.alist(config)``). The list is scanned oldest-first while
    remembering the id of the latest checkpoint seen so far. The first
    checkpoint whose metadata ``turn_id`` equals ``turn_id`` marks the
    start of the edited turn, and the remembered id is the rewind target.

    Scanning newest-first and returning the first non-matching checkpoint
    would be wrong: checkpoints of turns that came *after* ``turn_id``
    also fail the equality test and would be returned too early.

    The turn id is read from ``cp_tuple.metadata`` rather than from
    ``config["configurable"]``; the checkpoint id itself comes from
    ``cp_tuple.config["configurable"]["checkpoint_id"]``.

    Args:
        checkpoint_tuples: Checkpoint tuples, newest first.
        turn_id: Turn whose pre-state is wanted.

    Returns:
        The checkpoint id preceding the turn, or ``None`` when the turn's
        first checkpoint is also the oldest one (or the list is empty).
        If no checkpoint carries ``turn_id`` at all, the id of the newest
        readable checkpoint is returned.
    """
    candidate: str | None = None
    for entry in reversed(checkpoint_tuples):  # chronological order
        meta = getattr(entry, "metadata", None) or {}
        entry_turn = meta.get("turn_id") if isinstance(meta, dict) else None
        if entry_turn == turn_id:
            # First checkpoint of the edited turn: whatever we tracked so
            # far (possibly ``None``) is the rewind target.
            return candidate
        try:
            candidate = entry.config["configurable"]["checkpoint_id"]
        except (KeyError, TypeError):
            # Malformed config: keep the previously tracked candidate.
            pass
    return candidate
|
||||||
|
|
||||||
|
|
||||||
|
async def _revert_turns_for_regenerate(
    *,
    thread_id: int,
    chat_turn_ids: list[str],
    requester_user_id: str,
) -> dict:
    """Best-effort revert pass over every turn in ``chat_turn_ids``.

    Runs BEFORE the regenerate stream so the frontend can show
    partial-rollback feedback next to the new assistant turn. Within each
    turn the action rows are processed newest-first and every revert runs
    in its own SAVEPOINT, so a single failure never poisons the batch.

    Sequencing inside the request is revert THEN regenerate. The
    operation is not atomic and partial state is surfaced to the caller.

    If the final commit fails, the whole batch is rolled back. Rows that
    had been recorded as ``"reverted"`` are then re-labelled ``"failed"``
    so the summary never claims a revert that was not persisted.

    Args:
        thread_id: Thread the turns belong to.
        chat_turn_ids: Turns to revert, processed in the given order.
        requester_user_id: User on whose behalf the reverts run.

    Returns:
        A dict with ``status`` (``"ok"`` or ``"partial"``),
        ``chat_turn_ids``, ``total``, one counter per row status and
        ``results`` (a list of dumped ``RevertTurnActionResult`` models).
        ``total`` always equals the sum of the counters.
    """
    # Local imports keep the module's top-level imports tidy and avoid a
    # circular dependency at module-load time.
    from app.db import AgentActionLog as _AgentActionLog
    from app.routes.agent_revert_route import (
        RevertTurnActionResult,
        _classify_outcome,
        _OutcomeRollbackError,
        _was_already_reverted,
        _was_already_reverted_batch,
    )
    from app.services.revert_service import (
        can_revert,
        revert_action,
    )

    aggregated_results: list[dict] = []
    # Exhaustive counters keep ``total == sum(counters)`` true for the
    # ``data-revert-results`` payload.
    counts = {
        "reverted": 0,
        "already_reverted": 0,
        "not_reversible": 0,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }

    def _record(action, bucket: str, status: str, **fields) -> None:
        # Bump the counter, then append the dumped per-row result.
        counts[bucket] += 1
        aggregated_results.append(
            RevertTurnActionResult(
                action_id=action.id,
                tool_name=action.tool_name,
                status=status,
                **fields,
            ).model_dump()
        )

    async with shielded_async_session() as session:
        for chat_turn_id in chat_turn_ids:
            rows_stmt = (
                select(_AgentActionLog)
                .where(
                    _AgentActionLog.thread_id == thread_id,
                    _AgentActionLog.chat_turn_id == chat_turn_id,
                )
                .order_by(
                    _AgentActionLog.created_at.desc(),
                    _AgentActionLog.id.desc(),
                )
            )
            rows = (await session.execute(rows_stmt)).scalars().all()

            # One batched idempotency probe per turn instead of one
            # SELECT per row.
            eligible_ids = [r.id for r in rows if r.reverse_of is None]
            already_reverted_map = await _was_already_reverted_batch(
                session, action_ids=eligible_ids
            )

            for action in rows:
                # Rows that are themselves reverts are never reverted.
                if action.reverse_of is not None:
                    _record(
                        action,
                        "skipped",
                        "skipped",
                        message="Row is itself a revert action; skipped.",
                    )
                    continue

                existing_revert_id = already_reverted_map.get(action.id)
                if existing_revert_id is not None:
                    _record(
                        action,
                        "already_reverted",
                        "already_reverted",
                        new_action_id=existing_revert_id,
                    )
                    continue

                if not can_revert(
                    requester_user_id=requester_user_id,
                    action=action,
                    is_admin=False,
                ):
                    _record(
                        action,
                        "permission_denied",
                        "permission_denied",
                        message="You are not allowed to revert this action.",
                    )
                    continue

                try:
                    # Raising inside the SAVEPOINT discards this row's
                    # writes without touching earlier rows.
                    async with session.begin_nested():
                        outcome = await revert_action(
                            session,
                            action=action,
                            requester_user_id=requester_user_id,
                        )
                        if outcome.status != "ok":
                            raise _OutcomeRollbackError(outcome)
                except _OutcomeRollbackError as rollback:
                    outcome = rollback.outcome
                    classified = _classify_outcome(outcome)
                    bucket = (
                        "permission_denied"
                        if classified == "permission_denied"
                        else "not_reversible"
                    )
                    _record(action, bucket, classified, message=outcome.message)
                    continue
                except IntegrityError:
                    # A concurrent revert won the race after the batched
                    # pre-flight probe. Surface the winning revert id so
                    # the client can treat this as an idempotent success.
                    existing_revert_id = await _was_already_reverted(
                        session, action_id=action.id
                    )
                    _record(
                        action,
                        "already_reverted",
                        "already_reverted",
                        new_action_id=existing_revert_id,
                    )
                    continue
                except Exception as err:  # pragma: no cover — defensive
                    _logger.exception(
                        "Unexpected revert failure during regenerate batch "
                        "for action_id=%s",
                        action.id,
                    )
                    _record(
                        action,
                        "failed",
                        "failed",
                        error=str(err) or err.__class__.__name__,
                    )
                    continue

                _record(
                    action,
                    "reverted",
                    "reverted",
                    message=outcome.message,
                    new_action_id=outcome.new_action_id,
                )

        try:
            await session.commit()
        except Exception:
            _logger.exception(
                "[regenerate-revert] Final commit failed; rolling back batch."
            )
            await session.rollback()
            # Nothing from this batch was persisted, so rows recorded as
            # "reverted" must not be reported as such: their revert rows
            # no longer exist. Rows that were already reverted before
            # this batch are unaffected by the rollback.
            for entry in aggregated_results:
                if entry["status"] == "reverted":
                    entry["status"] = "failed"
                    entry["new_action_id"] = None
                    entry["error"] = (
                        "Revert was rolled back because the final commit failed."
                    )
            counts["failed"] += counts["reverted"]
            counts["reverted"] = 0

    has_partial = (
        counts["failed"] > 0
        or counts["not_reversible"] > 0
        or counts["permission_denied"] > 0
    )

    return {
        "status": "partial" if has_partial else "ok",
        "chat_turn_ids": chat_turn_ids,
        "total": len(aggregated_results),
        "reverted": counts["reverted"],
        "already_reverted": counts["already_reverted"],
        "not_reversible": counts["not_reversible"],
        "permission_denied": counts["permission_denied"],
        "failed": counts["failed"],
        "skipped": counts["skipped"],
        "results": aggregated_results,
    }
|
||||||
|
|
||||||
|
|
||||||
def _try_delete_sandbox(thread_id: int) -> None:
|
def _try_delete_sandbox(thread_id: int) -> None:
|
||||||
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
||||||
from app.agents.new_chat.sandbox import (
|
from app.agents.new_chat.sandbox import (
|
||||||
|
|
@ -574,6 +829,7 @@ async def get_thread_messages(
|
||||||
token_usage=TokenUsageSummary.model_validate(msg.token_usage)
|
token_usage=TokenUsageSummary.model_validate(msg.token_usage)
|
||||||
if msg.token_usage
|
if msg.token_usage
|
||||||
else None,
|
else None,
|
||||||
|
turn_id=msg.turn_id,
|
||||||
)
|
)
|
||||||
for msg in db_messages
|
for msg in db_messages
|
||||||
]
|
]
|
||||||
|
|
@ -1006,12 +1262,24 @@ async def append_message(
|
||||||
# Check thread-level access based on visibility
|
# Check thread-level access based on visibility
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
|
||||||
# Create message
|
# Create message. ``turn_id`` is the per-turn correlation id from
|
||||||
|
# ``configurable.turn_id`` (added in migration 136) — when the
|
||||||
|
# client streams it back to ``appendMessage``, we persist it so
|
||||||
|
# C1's edit-from-arbitrary-position can later map this message
|
||||||
|
# back to the LangGraph checkpoint that produced its turn.
|
||||||
|
raw_turn_id = raw_body.get("turn_id")
|
||||||
|
turn_id_value = (
|
||||||
|
str(raw_turn_id).strip()
|
||||||
|
if isinstance(raw_turn_id, str) and raw_turn_id.strip()
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
db_message = NewChatMessage(
|
db_message = NewChatMessage(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
role=message_role,
|
role=message_role,
|
||||||
content=content,
|
content=content,
|
||||||
author_id=user.id,
|
author_id=user.id,
|
||||||
|
turn_id=turn_id_value,
|
||||||
)
|
)
|
||||||
session.add(db_message)
|
session.add(db_message)
|
||||||
|
|
||||||
|
|
@ -1050,6 +1318,7 @@ async def append_message(
|
||||||
created_at=db_message.created_at,
|
created_at=db_message.created_at,
|
||||||
author_id=db_message.author_id,
|
author_id=db_message.author_id,
|
||||||
token_usage=None,
|
token_usage=None,
|
||||||
|
turn_id=db_message.turn_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|
@ -1373,43 +1642,123 @@ async def regenerate_response(
|
||||||
user_query_to_use = request.user_query
|
user_query_to_use = request.user_query
|
||||||
regenerate_image_urls: list[str] = []
|
regenerate_image_urls: list[str] = []
|
||||||
|
|
||||||
# Look through checkpoints to find the right one
|
# ---------------------------------------------------------------
|
||||||
# We want to find the checkpoint just before the last HumanMessage
|
# Edit-from-arbitrary-position. When the client passes
|
||||||
for i, cp_tuple in enumerate(checkpoint_tuples):
|
# ``from_message_id`` we look up its persisted ``turn_id`` (added
|
||||||
# Access the checkpoint's channel_values which contains "messages"
|
# in migration 136) and pick the checkpoint immediately before
|
||||||
checkpoint_data = cp_tuple.checkpoint
|
# that turn started.
|
||||||
channel_values = checkpoint_data.get("channel_values", {})
|
#
|
||||||
state_messages = channel_values.get("messages", [])
|
# Legacy graceful-degradation contract:
|
||||||
|
# * Rows persisted BEFORE migration 136 have ``turn_id IS NULL``.
|
||||||
|
# Returning 400 in that case is the wrong UX — the user is
|
||||||
|
# editing an old message in an existing thread and just wants
|
||||||
|
# it to work. We instead skip the checkpoint rewind (the
|
||||||
|
# stream falls back to the latest state) and skip the revert
|
||||||
|
# pass (no chat_turn_id available to walk). Deletion still
|
||||||
|
# uses ``created_at``, so the messages-after-cursor slice is
|
||||||
|
# correct on both legacy and post-136 rows.
|
||||||
|
# ---------------------------------------------------------------
|
||||||
|
from_message_turn_id: str | None = None
|
||||||
|
from_message_created_at: datetime | None = None
|
||||||
|
legacy_from_message: bool = False
|
||||||
|
if request.from_message_id is not None:
|
||||||
|
from_msg_row = await session.execute(
|
||||||
|
select(NewChatMessage).filter(
|
||||||
|
NewChatMessage.id == request.from_message_id,
|
||||||
|
NewChatMessage.thread_id == thread_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
from_msg = from_msg_row.scalars().first()
|
||||||
|
if from_msg is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="from_message_id not found in this thread.",
|
||||||
|
)
|
||||||
|
from_message_created_at = from_msg.created_at
|
||||||
|
if not from_msg.turn_id:
|
||||||
|
# Legacy row — surface the degradation in logs but let
|
||||||
|
# the request proceed with the slice-based delete and a
|
||||||
|
# cold-start checkpoint.
|
||||||
|
legacy_from_message = True
|
||||||
|
_logger.warning(
|
||||||
|
"[regenerate] from_message_id=%s on thread=%s has no "
|
||||||
|
"turn_id (legacy row pre-migration-136). Falling back "
|
||||||
|
"to slice-based delete without checkpoint rewind. "
|
||||||
|
"revert_actions=%s will be ignored.",
|
||||||
|
request.from_message_id,
|
||||||
|
thread_id,
|
||||||
|
request.revert_actions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from_message_turn_id = from_msg.turn_id
|
||||||
|
|
||||||
if state_messages:
|
# Walk oldest-to-newest and pick the LAST checkpoint whose
|
||||||
last_msg = state_messages[-1]
|
# ``turn_id`` differs from the edited turn — that's the state
|
||||||
# Find a checkpoint where the last message is NOT a HumanMessage
|
# immediately before this turn started running. We read from
|
||||||
# This means we're at a state before the user's last message
|
# ``metadata`` (the durable surface) rather than
|
||||||
if not isinstance(last_msg, HumanMessage):
|
# ``config["configurable"]`` so the lookup works across
|
||||||
# If no new user_query provided (reload), extract from a later checkpoint
|
# checkpointer implementations.
|
||||||
if user_query_to_use is None and i > 0:
|
target_checkpoint_id = _find_pre_turn_checkpoint_id(
|
||||||
# Get the user query from a more recent checkpoint
|
checkpoint_tuples,
|
||||||
for prev_cp_tuple in checkpoint_tuples[:i]:
|
turn_id=from_message_turn_id,
|
||||||
prev_checkpoint_data = prev_cp_tuple.checkpoint
|
)
|
||||||
prev_channel_values = prev_checkpoint_data.get(
|
if target_checkpoint_id is None and len(checkpoint_tuples) > 0:
|
||||||
"channel_values", {}
|
# Fall back to the oldest checkpoint — better than
|
||||||
)
|
# 400ing when the agent didn't checkpoint pre-turn
|
||||||
prev_messages = prev_channel_values.get("messages", [])
|
# (e.g. very first turn of the thread).
|
||||||
for msg in reversed(prev_messages):
|
target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
|
||||||
if isinstance(msg, HumanMessage):
|
|
||||||
q, imgs = split_langchain_human_content(msg.content)
|
|
||||||
user_query_to_use = q
|
|
||||||
regenerate_image_urls = imgs
|
|
||||||
break
|
|
||||||
if user_query_to_use is not None and (
|
|
||||||
str(user_query_to_use).strip() or regenerate_image_urls
|
|
||||||
):
|
|
||||||
break
|
|
||||||
|
|
||||||
target_checkpoint_id = cp_tuple.config["configurable"][
|
|
||||||
"checkpoint_id"
|
"checkpoint_id"
|
||||||
]
|
]
|
||||||
break
|
|
||||||
|
# Look through checkpoints to find the right one
|
||||||
|
# We want to find the checkpoint just before the last HumanMessage.
|
||||||
|
# We enter this branch when:
|
||||||
|
# * the client did NOT pin ``from_message_id`` (legacy reload/edit), OR
|
||||||
|
# * the client pinned ``from_message_id`` but the row is a
|
||||||
|
# legacy pre-migration-136 row with no ``turn_id`` (we
|
||||||
|
# downgraded to the same heuristic as a regular reload).
|
||||||
|
# We DO skip it when a real turn_id pinned ``target_checkpoint_id``
|
||||||
|
# — that's the C1 happy path and the heuristic below would just
|
||||||
|
# re-derive a worse target.
|
||||||
|
if request.from_message_id is None or legacy_from_message:
|
||||||
|
for i, cp_tuple in enumerate(checkpoint_tuples):
|
||||||
|
# Access the checkpoint's channel_values which contains "messages"
|
||||||
|
checkpoint_data = cp_tuple.checkpoint
|
||||||
|
channel_values = checkpoint_data.get("channel_values", {})
|
||||||
|
state_messages = channel_values.get("messages", [])
|
||||||
|
|
||||||
|
if state_messages:
|
||||||
|
last_msg = state_messages[-1]
|
||||||
|
# Find a checkpoint where the last message is NOT a HumanMessage
|
||||||
|
# This means we're at a state before the user's last message
|
||||||
|
if not isinstance(last_msg, HumanMessage):
|
||||||
|
# If no new user_query provided (reload), extract from a later checkpoint
|
||||||
|
if user_query_to_use is None and i > 0:
|
||||||
|
# Get the user query from a more recent checkpoint
|
||||||
|
for prev_cp_tuple in checkpoint_tuples[:i]:
|
||||||
|
prev_checkpoint_data = prev_cp_tuple.checkpoint
|
||||||
|
prev_channel_values = prev_checkpoint_data.get(
|
||||||
|
"channel_values", {}
|
||||||
|
)
|
||||||
|
prev_messages = prev_channel_values.get("messages", [])
|
||||||
|
for msg in reversed(prev_messages):
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
q, imgs = split_langchain_human_content(
|
||||||
|
msg.content
|
||||||
|
)
|
||||||
|
user_query_to_use = q
|
||||||
|
regenerate_image_urls = imgs
|
||||||
|
break
|
||||||
|
if user_query_to_use is not None and (
|
||||||
|
str(user_query_to_use).strip()
|
||||||
|
or regenerate_image_urls
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
|
target_checkpoint_id = cp_tuple.config["configurable"][
|
||||||
|
"checkpoint_id"
|
||||||
|
]
|
||||||
|
break
|
||||||
|
|
||||||
# If we couldn't find a good checkpoint, try alternative approaches
|
# If we couldn't find a good checkpoint, try alternative approaches
|
||||||
if target_checkpoint_id is None and checkpoint_tuples:
|
if target_checkpoint_id is None and checkpoint_tuples:
|
||||||
|
|
@ -1472,18 +1821,51 @@ async def regenerate_response(
|
||||||
detail="Could not determine user query for regeneration. Please provide a user_query.",
|
detail="Could not determine user query for regeneration. Please provide a user_query.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the last two messages to delete AFTER streaming succeeds
|
# Get the messages to delete AFTER streaming succeeds.
|
||||||
# This prevents data loss if streaming fails
|
# This prevents data loss if streaming fails.
|
||||||
last_messages_result = await session.execute(
|
#
|
||||||
select(NewChatMessage)
|
# When ``from_message_id`` is set we slice from that message
|
||||||
.filter(NewChatMessage.thread_id == thread_id)
|
# forward (using ``created_at`` so we also catch any tool/system
|
||||||
.order_by(NewChatMessage.created_at.desc())
|
# messages persisted into the same turn). Otherwise
|
||||||
.limit(2)
|
# we keep the legacy "last 2 messages" rewind.
|
||||||
)
|
if request.from_message_id is not None and from_message_created_at is not None:
|
||||||
|
last_messages_result = await session.execute(
|
||||||
|
select(NewChatMessage)
|
||||||
|
.filter(
|
||||||
|
NewChatMessage.thread_id == thread_id,
|
||||||
|
NewChatMessage.created_at >= from_message_created_at,
|
||||||
|
)
|
||||||
|
.order_by(NewChatMessage.created_at.desc())
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
last_messages_result = await session.execute(
|
||||||
|
select(NewChatMessage)
|
||||||
|
.filter(NewChatMessage.thread_id == thread_id)
|
||||||
|
.order_by(NewChatMessage.created_at.desc())
|
||||||
|
.limit(2)
|
||||||
|
)
|
||||||
messages_to_delete = list(last_messages_result.scalars().all())
|
messages_to_delete = list(last_messages_result.scalars().all())
|
||||||
|
|
||||||
message_ids_to_delete = [msg.id for msg in messages_to_delete]
|
message_ids_to_delete = [msg.id for msg in messages_to_delete]
|
||||||
|
|
||||||
|
# When revert_actions is requested, collect the set of
|
||||||
|
# ``chat_turn_id``s present in the slice we're about to delete.
|
||||||
|
# Each one will be reverted (best-effort) BEFORE the regenerate
|
||||||
|
# stream begins. Legacy rows have ``turn_id=None`` and silently
|
||||||
|
# contribute nothing — we already logged the degradation above.
|
||||||
|
revert_turn_ids: list[str] = []
|
||||||
|
if (
|
||||||
|
request.revert_actions
|
||||||
|
and request.from_message_id is not None
|
||||||
|
and not legacy_from_message
|
||||||
|
):
|
||||||
|
seen_turns: set[str] = set()
|
||||||
|
for msg in messages_to_delete:
|
||||||
|
tid = msg.turn_id
|
||||||
|
if tid and tid not in seen_turns:
|
||||||
|
seen_turns.add(tid)
|
||||||
|
revert_turn_ids.append(tid)
|
||||||
|
|
||||||
# Get search space for LLM config
|
# Get search space for LLM config
|
||||||
search_space_result = await session.execute(
|
search_space_result = await session.execute(
|
||||||
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
||||||
|
|
@ -1507,6 +1889,24 @@ async def regenerate_response(
|
||||||
# This prevents data loss if streaming fails (network error, LLM error, etc.)
|
# This prevents data loss if streaming fails (network error, LLM error, etc.)
|
||||||
async def stream_with_cleanup():
|
async def stream_with_cleanup():
|
||||||
streaming_completed = False
|
streaming_completed = False
|
||||||
|
# Best-effort revert pass BEFORE the regenerate stream begins.
|
||||||
|
# Each turn is reverted independently (per-row SAVEPOINTs
|
||||||
|
# inside the route helper) and the per-action results are surfaced
|
||||||
|
# on a single ``data-revert-results`` SSE event so the frontend
|
||||||
|
# can render any failed rows alongside the new turn. Failures here
|
||||||
|
# do NOT abort the regeneration — partial rollback is documented
|
||||||
|
# behaviour.
|
||||||
|
if revert_turn_ids:
|
||||||
|
revert_results = await _revert_turns_for_regenerate(
|
||||||
|
thread_id=thread_id,
|
||||||
|
chat_turn_ids=revert_turn_ids,
|
||||||
|
requester_user_id=str(user.id),
|
||||||
|
)
|
||||||
|
envelope = {
|
||||||
|
"type": "data-revert-results",
|
||||||
|
"data": revert_results,
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
|
||||||
try:
|
try:
|
||||||
async for chunk in stream_new_chat(
|
async for chunk in stream_new_chat(
|
||||||
user_query=str(user_query_to_use),
|
user_query=str(user_query_to_use),
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,11 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
|
||||||
author_display_name: str | None = None
|
author_display_name: str | None = None
|
||||||
author_avatar_url: str | None = None
|
author_avatar_url: str | None = None
|
||||||
token_usage: TokenUsageSummary | None = None
|
token_usage: TokenUsageSummary | None = None
|
||||||
|
# Per-turn correlation id (``f"{chat_id}:{ms}"``) from
|
||||||
|
# ``configurable.turn_id`` at streaming time. Nullable because
|
||||||
|
# legacy rows predate the column; clients should treat NULL as
|
||||||
|
# "edit-from-this-message is unavailable".
|
||||||
|
turn_id: str | None = None
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -241,6 +246,15 @@ class RegenerateRequest(BaseModel):
|
||||||
|
|
||||||
For edit, optional user_images (when not None) replaces image URLs resolved from
|
For edit, optional user_images (when not None) replaces image URLs resolved from
|
||||||
checkpoint/DB so the client can send the full user turn (text and/or images).
|
checkpoint/DB so the client can send the full user turn (text and/or images).
|
||||||
|
|
||||||
|
Edit-from-arbitrary-position. When ``from_message_id`` is provided
|
||||||
|
the route slices conversation history starting at that message (instead of
|
||||||
|
the legacy "last 2 messages" rewind), rewinds the LangGraph checkpoint by
|
||||||
|
matching ``configurable.turn_id`` stored on the message (added in migration 136), and
|
||||||
|
optionally reverts every reversible action emitted in turns at or after
|
||||||
|
``from_message_id``. The revert step is best-effort and runs BEFORE the
|
||||||
|
regenerate stream — partial failures are surfaced via SSE
|
||||||
|
``data-revert-results`` and do not abort the regeneration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
|
|
@ -257,6 +271,28 @@ class RegenerateRequest(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB",
|
description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB",
|
||||||
)
|
)
|
||||||
|
from_message_id: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Message id to rewind to. When set, history is sliced "
|
||||||
|
"from this message forward and the LangGraph checkpoint is "
|
||||||
|
"rewound to the state immediately preceding this turn. Legacy "
|
||||||
|
"rows that predate migration 136 have ``turn_id=None`` and "
|
||||||
|
"still process — the route logs a warning, skips the "
|
||||||
|
"checkpoint rewind, and ignores ``revert_actions`` (no "
|
||||||
|
"chat_turn_id available to walk)."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
revert_actions: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"When true, every reversible action emitted at or "
|
||||||
|
"after ``from_message_id`` is reverted before the regenerate "
|
||||||
|
"stream begins. Per-action results are surfaced via the "
|
||||||
|
"``data-revert-results`` SSE event. Partial failures DO NOT "
|
||||||
|
"abort the regeneration."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _validate_regenerate_user_images(self) -> Self:
|
def _validate_regenerate_user_images(self) -> Self:
|
||||||
|
|
@ -264,6 +300,14 @@ class RegenerateRequest(BaseModel):
|
||||||
raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed")
|
raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
def _validate_revert_actions_requires_from_message(self) -> Self:
    """Reject a revert request that has no rewind anchor.

    ``revert_actions`` is only meaningful together with
    ``from_message_id``; the combination "revert requested, no message
    id" is refused.

    Returns:
        The validated model instance, unchanged.

    Raises:
        ValueError: ``revert_actions`` is truthy while
            ``from_message_id`` is ``None``.
    """
    # Nothing to check when no revert was requested.
    if not self.revert_actions:
        return self
    # A revert with an explicit anchor message is fine.
    if self.from_message_id is not None:
        return self
    raise ValueError(
        "revert_actions requires from_message_id; specify which message to rewind to"
    )
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Agent Tools Schemas
|
# Agent Tools Schemas
|
||||||
|
|
|
||||||
|
|
@ -584,13 +584,24 @@ class VercelStreamingService:
|
||||||
# Tool Parts
|
# Tool Parts
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
||||||
def format_tool_input_start(self, tool_call_id: str, tool_name: str) -> str:
|
def format_tool_input_start(
|
||||||
|
self,
|
||||||
|
tool_call_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
*,
|
||||||
|
langchain_tool_call_id: str | None = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format the start of tool input streaming.
|
Format the start of tool input streaming.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_call_id: The unique tool call identifier
|
tool_call_id: The unique tool call identifier (synthetic, derived
|
||||||
tool_name: The name of the tool being called
|
from LangGraph ``run_id`` so the frontend has a stable card id).
|
||||||
|
tool_name: The name of the tool being called.
|
||||||
|
langchain_tool_call_id: Optional authoritative LangChain
|
||||||
|
``tool_call.id``. When set, surfaces as
|
||||||
|
``langchainToolCallId`` so the frontend can join this card
|
||||||
|
to the action-log row written by ``ActionLogMiddleware``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: SSE formatted tool input start part
|
str: SSE formatted tool input start part
|
||||||
|
|
@ -598,13 +609,14 @@ class VercelStreamingService:
|
||||||
Example output:
|
Example output:
|
||||||
data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"}
|
data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"}
|
||||||
"""
|
"""
|
||||||
return self._format_sse(
|
payload: dict[str, Any] = {
|
||||||
{
|
"type": "tool-input-start",
|
||||||
"type": "tool-input-start",
|
"toolCallId": tool_call_id,
|
||||||
"toolCallId": tool_call_id,
|
"toolName": tool_name,
|
||||||
"toolName": tool_name,
|
}
|
||||||
}
|
if langchain_tool_call_id:
|
||||||
)
|
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||||
|
return self._format_sse(payload)
|
||||||
|
|
||||||
def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:
|
def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -629,7 +641,12 @@ class VercelStreamingService:
|
||||||
)
|
)
|
||||||
|
|
||||||
def format_tool_input_available(
|
def format_tool_input_available(
|
||||||
self, tool_call_id: str, tool_name: str, input_data: dict[str, Any]
|
self,
|
||||||
|
tool_call_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
input_data: dict[str, Any],
|
||||||
|
*,
|
||||||
|
langchain_tool_call_id: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format the completion of tool input.
|
Format the completion of tool input.
|
||||||
|
|
@ -638,6 +655,8 @@ class VercelStreamingService:
|
||||||
tool_call_id: The tool call identifier
|
tool_call_id: The tool call identifier
|
||||||
tool_name: The name of the tool
|
tool_name: The name of the tool
|
||||||
input_data: The complete tool input parameters
|
input_data: The complete tool input parameters
|
||||||
|
langchain_tool_call_id: Optional authoritative LangChain
|
||||||
|
``tool_call.id`` (see ``format_tool_input_start``).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: SSE formatted tool input available part
|
str: SSE formatted tool input available part
|
||||||
|
|
@ -645,22 +664,34 @@ class VercelStreamingService:
|
||||||
Example output:
|
Example output:
|
||||||
data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}}
|
data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}}
|
||||||
"""
|
"""
|
||||||
return self._format_sse(
|
payload: dict[str, Any] = {
|
||||||
{
|
"type": "tool-input-available",
|
||||||
"type": "tool-input-available",
|
"toolCallId": tool_call_id,
|
||||||
"toolCallId": tool_call_id,
|
"toolName": tool_name,
|
||||||
"toolName": tool_name,
|
"input": input_data,
|
||||||
"input": input_data,
|
}
|
||||||
}
|
if langchain_tool_call_id:
|
||||||
)
|
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||||
|
return self._format_sse(payload)
|
||||||
|
|
||||||
def format_tool_output_available(self, tool_call_id: str, output: Any) -> str:
|
def format_tool_output_available(
|
||||||
|
self,
|
||||||
|
tool_call_id: str,
|
||||||
|
output: Any,
|
||||||
|
*,
|
||||||
|
langchain_tool_call_id: str | None = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format tool execution output.
|
Format tool execution output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_call_id: The tool call identifier
|
tool_call_id: The tool call identifier
|
||||||
output: The tool execution result
|
output: The tool execution result
|
||||||
|
langchain_tool_call_id: Optional authoritative LangChain
|
||||||
|
``tool_call.id`` extracted from ``ToolMessage.tool_call_id``.
|
||||||
|
When set, the frontend can backfill any card whose
|
||||||
|
``langchainToolCallId`` was not yet known at
|
||||||
|
``tool-input-start`` time.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: SSE formatted tool output available part
|
str: SSE formatted tool output available part
|
||||||
|
|
@ -668,13 +699,14 @@ class VercelStreamingService:
|
||||||
Example output:
|
Example output:
|
||||||
data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}}
|
data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}}
|
||||||
"""
|
"""
|
||||||
return self._format_sse(
|
payload: dict[str, Any] = {
|
||||||
{
|
"type": "tool-output-available",
|
||||||
"type": "tool-output-available",
|
"toolCallId": tool_call_id,
|
||||||
"toolCallId": tool_call_id,
|
"output": output,
|
||||||
"output": output,
|
}
|
||||||
}
|
if langchain_tool_call_id:
|
||||||
)
|
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||||
|
return self._format_sse(payload)
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Step Parts
|
# Step Parts
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,9 @@ Operation outcomes mirror the plan:
|
||||||
|
|
||||||
* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
|
* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
|
||||||
:class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
|
:class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
|
||||||
written before the original mutation.
|
written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh
|
||||||
|
row from the snapshot; ``write_file`` create / ``mkdir`` DELETE the row
|
||||||
|
that was created; everything else is an in-place restore.
|
||||||
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
|
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
|
||||||
the inverse tool through the agent's normal permission stack (NOT
|
the inverse tool through the agent's normal permission stack (NOT
|
||||||
bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``.
|
bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``.
|
||||||
|
|
@ -18,6 +20,11 @@ Operation outcomes mirror the plan:
|
||||||
A successful revert appends a NEW row to ``agent_action_log`` with
|
A successful revert appends a NEW row to ``agent_action_log`` with
|
||||||
``reverse_of=<original_action_id>`` and the requesting user's
|
``reverse_of=<original_action_id>`` and the requesting user's
|
||||||
``user_id``, preserving an auditable chain.
|
``user_id``, preserving an auditable chain.
|
||||||
|
|
||||||
|
Dispatch must be exact-match (``tool_name == name``), NOT prefix matching.
|
||||||
|
``"rmdir".startswith("rm")`` would otherwise mis-route directory revert
|
||||||
|
to the document branch (and ``delete_note`` vs ``delete_folder`` is the
|
||||||
|
same trap waiting to happen).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -25,17 +32,31 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.path_resolver import (
|
||||||
|
DOCUMENTS_ROOT,
|
||||||
|
safe_filename,
|
||||||
|
safe_folder_segment,
|
||||||
|
)
|
||||||
from app.db import (
|
from app.db import (
|
||||||
AgentActionLog,
|
AgentActionLog,
|
||||||
|
Chunk,
|
||||||
|
Document,
|
||||||
DocumentRevision,
|
DocumentRevision,
|
||||||
|
DocumentType,
|
||||||
|
Folder,
|
||||||
FolderRevision,
|
FolderRevision,
|
||||||
NewChatThread,
|
NewChatThread,
|
||||||
)
|
)
|
||||||
|
from app.utils.document_converters import (
|
||||||
|
embed_texts,
|
||||||
|
generate_content_hash,
|
||||||
|
generate_unique_identifier_hash,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -110,14 +131,244 @@ def can_revert(
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Revert paths
|
# Helper: reconstruct virtual path from a snapshot
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _virtual_path_from_snapshot(
    session: AsyncSession,
    revision: DocumentRevision,
) -> str | None:
    """Work out the virtual path a document had when it was snapshotted.

    Two sources are tried, in order:

    1. ``revision.metadata_before["virtual_path"]`` — used as-is when it is
       a string that starts with ``DOCUMENTS_ROOT``.
    2. A path composed from the folder chain starting at
       ``revision.folder_id_before`` (followed upwards via ``parent_id``)
       plus the sanitised ``revision.title_before``.

    Args:
        session: Session used to load each :class:`Folder` in the chain.
        revision: Snapshot row to read ``metadata_before``,
            ``title_before`` and ``folder_id_before`` from.

    Returns:
        The reconstructed path, or ``None`` when the snapshot has no usable
        title or a folder in the chain can no longer be loaded.
    """
    snapshot_meta = revision.metadata_before or {}
    if isinstance(snapshot_meta, dict):
        recorded_path = snapshot_meta.get("virtual_path")
        if isinstance(recorded_path, str) and recorded_path.startswith(
            DOCUMENTS_ROOT
        ):
            return recorded_path

    doc_title = revision.title_before
    if not (isinstance(doc_title, str) and doc_title):
        return None

    # Climb from the document's folder towards the root, collecting one
    # sanitised segment per folder (leaf first).
    leaf_to_root: list[str] = []
    seen_folder_ids: set[int] = set()
    folder_id: int | None = revision.folder_id_before
    while folder_id is not None:
        if folder_id in seen_folder_ids:
            # Already visited: stop climbing and keep what was collected.
            break
        seen_folder_ids.add(folder_id)
        ancestor = await session.get(Folder, folder_id)
        if ancestor is None:
            return None
        leaf_to_root.append(safe_folder_segment(str(ancestor.name or "")))
        folder_id = ancestor.parent_id

    if leaf_to_root:
        directory = f"{DOCUMENTS_ROOT}/" + "/".join(reversed(leaf_to_root))
    else:
        directory = DOCUMENTS_ROOT
    return f"{directory}/{safe_filename(doc_title)}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Document revision restore (write/edit/move/rm)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _set_field(target: Any, field: str, value: Any) -> None:
    """Assign ``value`` to ``target.<field>`` unless ``value`` is ``None``.

    A ``None`` value leaves the attribute untouched; any other value,
    including falsy ones such as ``0`` or ``""``, is written.
    """
    if value is None:
        return
    setattr(target, field, value)
|
||||||
|
|
||||||
|
|
||||||
|
async def _restore_in_place_document(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
revision: DocumentRevision,
|
||||||
|
) -> RevertOutcome:
|
||||||
|
"""Apply an in-place restore to an existing :class:`Document`."""
|
||||||
|
if revision.document_id is None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="tool_unavailable",
|
||||||
|
message=(
|
||||||
|
"Original document was hard-deleted; in-place restore is not possible."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
doc = await session.get(Document, revision.document_id)
|
||||||
|
if doc is None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="tool_unavailable",
|
||||||
|
message="Original document has been deleted; revert cannot proceed.",
|
||||||
|
)
|
||||||
|
|
||||||
|
_set_field(doc, "content", revision.content_before)
|
||||||
|
_set_field(doc, "source_markdown", revision.content_before)
|
||||||
|
_set_field(doc, "title", revision.title_before)
|
||||||
|
_set_field(doc, "folder_id", revision.folder_id_before)
|
||||||
|
metadata_before = revision.metadata_before or {}
|
||||||
|
if isinstance(metadata_before, dict) and metadata_before:
|
||||||
|
doc.document_metadata = dict(metadata_before)
|
||||||
|
|
||||||
|
if isinstance(revision.content_before, str):
|
||||||
|
doc.content_hash = generate_content_hash(
|
||||||
|
revision.content_before, doc.search_space_id
|
||||||
|
)
|
||||||
|
|
||||||
|
virtual_path = await _virtual_path_from_snapshot(session, revision)
|
||||||
|
if virtual_path:
|
||||||
|
doc.unique_identifier_hash = generate_unique_identifier_hash(
|
||||||
|
DocumentType.NOTE,
|
||||||
|
virtual_path,
|
||||||
|
doc.search_space_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks_before = revision.chunks_before
|
||||||
|
if isinstance(chunks_before, list):
|
||||||
|
await session.execute(delete(Chunk).where(Chunk.document_id == doc.id))
|
||||||
|
chunk_texts = [
|
||||||
|
str(c.get("content"))
|
||||||
|
for c in chunks_before
|
||||||
|
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
||||||
|
]
|
||||||
|
if chunk_texts:
|
||||||
|
chunk_embeddings = embed_texts(chunk_texts)
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
||||||
|
for text, embedding in zip(
|
||||||
|
chunk_texts, chunk_embeddings, strict=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if isinstance(revision.content_before, str):
|
||||||
|
doc.embedding = embed_texts([revision.content_before])[0]
|
||||||
|
|
||||||
|
doc.updated_at = datetime.now(UTC)
|
||||||
|
return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
||||||
|
|
||||||
|
|
||||||
|
async def _reinsert_document_from_revision(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Recreate a deleted :class:`Document` from its snapshot (``rm`` revert).

    The new row is always created as ``DocumentType.NOTE``. Its content,
    title, folder and metadata come from the snapshot; hashes and
    embeddings are recomputed. On success ``revision.document_id`` is
    pointed at the new row.

    Args:
        session: Session used for the collision lookup and the inserts.
        revision: Snapshot written before the document was deleted.

    Returns:
        ``not_reversible`` when the snapshot lacks a title, string content,
        or a resolvable virtual path; ``tool_unavailable`` when a document
        with the same identifier hash already exists in the search space;
        ``ok`` once the document (and any snapshot chunks) are re-added.
    """
    snapshot_title = revision.title_before
    if not (isinstance(snapshot_title, str) and snapshot_title):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks title_before; cannot recreate document.",
        )
    snapshot_body = revision.content_before
    if not isinstance(snapshot_body, str):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks content_before; cannot recreate document.",
        )

    restored_path = await _virtual_path_from_snapshot(session, revision)
    if not restored_path:
        return RevertOutcome(
            status="not_reversible",
            message=(
                "Snapshot is missing both metadata_before['virtual_path'] AND "
                "a resolvable (folder_id_before, title_before) pair."
            ),
        )

    space_id = revision.search_space_id
    identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        restored_path,
        space_id,
    )

    # Refuse to recreate the row if a live document already occupies the path.
    occupant_lookup = select(Document.id).where(
        Document.search_space_id == space_id,
        Document.unique_identifier_hash == identifier_hash,
    )
    occupant = await session.execute(occupant_lookup)
    if occupant.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                f"A document already exists at '{restored_path}'; revert would "
                "collide. Move the live doc out of the way first."
            ),
        )

    # Work on a copy so the snapshot's own metadata is never mutated.
    snapshot_meta = revision.metadata_before
    restored_meta = dict(snapshot_meta) if isinstance(snapshot_meta, dict) else {}
    restored_meta["virtual_path"] = restored_path

    recreated = Document(
        title=snapshot_title,
        document_type=DocumentType.NOTE,
        document_metadata=restored_meta,
        content=snapshot_body,
        content_hash=generate_content_hash(snapshot_body, space_id),
        unique_identifier_hash=identifier_hash,
        source_markdown=snapshot_body,
        search_space_id=space_id,
        folder_id=revision.folder_id_before,
        updated_at=datetime.now(UTC),
    )
    session.add(recreated)
    # Flush before chunks are attached: they reference ``recreated.id``.
    await session.flush()

    recreated.embedding = embed_texts([snapshot_body])[0]

    # Keep only chunk entries that are dicts carrying string content.
    texts: list[str] = []
    snapshot_chunks = revision.chunks_before
    if isinstance(snapshot_chunks, list):
        for entry in snapshot_chunks:
            if not isinstance(entry, dict):
                continue
            chunk_body = entry.get("content")
            if isinstance(chunk_body, str):
                texts.append(str(chunk_body))
    if texts:
        vectors = embed_texts(texts)
        restored_chunks = [
            Chunk(document_id=recreated.id, content=body, embedding=vector)
            for body, vector in zip(texts, vectors, strict=True)
        ]
        session.add_all(restored_chunks)

    # Point the snapshot at the recreated row so a later revert of the same
    # snapshot finds a live document.
    revision.document_id = recreated.id
    return RevertOutcome(
        status="ok",
        message=f"Re-inserted document '{snapshot_title}' from snapshot.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def _delete_created_document(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Undo a document creation by deleting the row the snapshot points at.

    Args:
        session: Session the ``DELETE`` is executed on.
        revision: Snapshot whose ``document_id`` names the created row.

    Returns:
        An ``ok`` outcome in both cases: when ``document_id`` is ``None``
        (nothing to delete) and after the ``DELETE`` has been issued.
    """
    created_id = revision.document_id
    if created_id is None:
        return RevertOutcome(
            status="ok",
            message="No live row to delete (already removed elsewhere).",
        )

    removal = delete(Document).where(Document.id == created_id)
    await session.execute(removal)
    return RevertOutcome(
        status="ok",
        message="Deleted the document that was created by this action.",
    )
|
||||||
|
|
||||||
|
|
||||||
async def _restore_document_revision(
|
async def _restore_document_revision(
|
||||||
session: AsyncSession, *, action: AgentActionLog
|
session: AsyncSession, *, action: AgentActionLog
|
||||||
) -> RevertOutcome:
|
) -> RevertOutcome:
|
||||||
"""Restore the most recent :class:`DocumentRevision` for ``action``."""
|
"""Dispatch document-level revert based on ``action.tool_name``."""
|
||||||
stmt = (
|
stmt = (
|
||||||
select(DocumentRevision)
|
select(DocumentRevision)
|
||||||
.where(DocumentRevision.agent_action_id == action.id)
|
.where(DocumentRevision.agent_action_id == action.id)
|
||||||
|
|
@ -132,23 +383,111 @@ async def _restore_document_revision(
|
||||||
message="No document_revisions row tied to this action.",
|
message="No document_revisions row tied to this action.",
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.db import Document # late import to avoid cycles at module load
|
tool_name = (action.tool_name or "").lower()
|
||||||
|
|
||||||
doc = await session.get(Document, revision.document_id)
|
if tool_name == "rm":
|
||||||
if doc is None:
|
return await _reinsert_document_from_revision(session, revision=revision)
|
||||||
|
|
||||||
|
if tool_name == "write_file" and revision.content_before is None:
|
||||||
|
return await _delete_created_document(session, revision=revision)
|
||||||
|
|
||||||
|
return await _restore_in_place_document(session, revision=revision)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Folder revision restore (mkdir/rmdir/rename/move)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _restore_in_place_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Write a folder snapshot's name/parent/position back onto the live row.

    Snapshot fields that are ``None`` are skipped, so the live value for
    that field is kept.

    Args:
        session: Session used to load the :class:`Folder`.
        revision: Snapshot holding ``name_before``, ``parent_id_before``
            and ``position_before``.

    Returns:
        ``tool_unavailable`` when the snapshot has no ``folder_id`` or the
        folder can no longer be loaded; ``ok`` after the fields are restored.
    """
    if revision.folder_id is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder was hard-deleted; in-place restore is impossible.",
        )

    live_folder = await session.get(Folder, revision.folder_id)
    if live_folder is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder has been deleted; revert cannot proceed.",
        )

    snapshot_fields = (
        ("name", revision.name_before),
        ("parent_id", revision.parent_id_before),
        ("position", revision.position_before),
    )
    for attr, old_value in snapshot_fields:
        # ``None`` means "nothing recorded": leave the live value alone.
        if old_value is not None:
            setattr(live_folder, attr, old_value)

    live_folder.updated_at = datetime.now(UTC)
    return RevertOutcome(status="ok", message="Folder restored from snapshot.")
|
||||||
|
|
||||||
|
|
||||||
|
async def _reinsert_folder_from_revision(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Recreate a deleted :class:`Folder` from its snapshot row.

    On success ``revision.folder_id`` is pointed at the new row.

    Args:
        session: Session the new folder is added to and flushed on.
        revision: Snapshot holding the folder's previous name, parent,
            position and search space.

    Returns:
        ``not_reversible`` when ``name_before`` is missing or empty;
        otherwise ``ok`` once the folder has been inserted.
    """
    folder_name = revision.name_before
    if not (isinstance(folder_name, str) and folder_name):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks name_before; cannot recreate folder.",
        )

    recreated = Folder(
        name=folder_name,
        parent_id=revision.parent_id_before,
        position=revision.position_before,
        search_space_id=revision.search_space_id,
        updated_at=datetime.now(UTC),
    )
    session.add(recreated)
    # Flush so the new primary key is available for the snapshot below.
    await session.flush()

    revision.folder_id = recreated.id
    return RevertOutcome(
        status="ok",
        message=f"Re-inserted folder '{folder_name}' from snapshot.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def _delete_created_folder(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
revision: FolderRevision,
|
||||||
|
) -> RevertOutcome:
|
||||||
|
if revision.folder_id is None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="ok",
|
||||||
|
message="No live folder row to delete (already removed elsewhere).",
|
||||||
|
)
|
||||||
|
folder_id = revision.folder_id
|
||||||
|
|
||||||
|
has_doc = await session.execute(
|
||||||
|
select(Document.id).where(Document.folder_id == folder_id).limit(1)
|
||||||
|
)
|
||||||
|
if has_doc.scalar_one_or_none() is not None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="tool_unavailable",
|
||||||
|
message=(
|
||||||
|
"Folder is no longer empty (documents have been added since "
|
||||||
|
"mkdir); cannot revert."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
has_child = await session.execute(
|
||||||
|
select(Folder.id).where(Folder.parent_id == folder_id).limit(1)
|
||||||
|
)
|
||||||
|
if has_child.scalar_one_or_none() is not None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="tool_unavailable",
|
||||||
|
message=(
|
||||||
|
"Folder is no longer empty (sub-folders have been added "
|
||||||
|
"since mkdir); cannot revert."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if revision.content_before is not None:
|
await session.execute(delete(Folder).where(Folder.id == folder_id))
|
||||||
doc.content = revision.content_before
|
return RevertOutcome(
|
||||||
if revision.title_before is not None:
|
status="ok",
|
||||||
doc.title = revision.title_before
|
message="Deleted the folder that was created by this action.",
|
||||||
if revision.folder_id_before is not None:
|
)
|
||||||
doc.folder_id = revision.folder_id_before
|
|
||||||
doc.updated_at = datetime.now(UTC)
|
|
||||||
return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
|
||||||
|
|
||||||
|
|
||||||
async def _restore_folder_revision(
|
async def _restore_folder_revision(
|
||||||
|
|
@ -168,41 +507,44 @@ async def _restore_folder_revision(
|
||||||
message="No folder_revisions row tied to this action.",
|
message="No folder_revisions row tied to this action.",
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.db import Folder
|
tool_name = (action.tool_name or "").lower()
|
||||||
|
|
||||||
folder = await session.get(Folder, revision.folder_id)
|
if tool_name == "rmdir":
|
||||||
if folder is None:
|
return await _reinsert_folder_from_revision(session, revision=revision)
|
||||||
return RevertOutcome(
|
|
||||||
status="tool_unavailable",
|
|
||||||
message="Original folder has been deleted; revert cannot proceed.",
|
|
||||||
)
|
|
||||||
|
|
||||||
if revision.name_before is not None:
|
if tool_name == "mkdir":
|
||||||
folder.name = revision.name_before
|
return await _delete_created_folder(session, revision=revision)
|
||||||
if revision.parent_id_before is not None:
|
|
||||||
folder.parent_id = revision.parent_id_before
|
return await _restore_in_place_folder(session, revision=revision)
|
||||||
if revision.position_before is not None:
|
|
||||||
folder.position = revision.position_before
|
|
||||||
folder.updated_at = datetime.now(UTC)
|
|
||||||
return RevertOutcome(status="ok", message="Folder restored from snapshot.")
|
|
||||||
|
|
||||||
|
|
||||||
# Tool-name prefixes that route to KB document / folder revert paths. Kept
|
# ---------------------------------------------------------------------------
|
||||||
# as data so a future PR adding new KB-owned tools doesn't have to touch
|
# Dispatch
|
||||||
# this module's control flow.
|
# ---------------------------------------------------------------------------
|
||||||
_DOC_TOOL_PREFIXES: tuple[str, ...] = (
|
#
|
||||||
"edit_file",
|
# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``.
|
||||||
"write_file",
|
# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and
|
||||||
"update_memory",
|
# ``delete_note``/``delete_folder``.
|
||||||
"create_note",
|
|
||||||
"update_note",
|
_DOC_TOOLS: frozenset[str] = frozenset(
|
||||||
"delete_note",
|
{
|
||||||
|
"edit_file",
|
||||||
|
"write_file",
|
||||||
|
"move_file",
|
||||||
|
"rm",
|
||||||
|
"update_memory",
|
||||||
|
"create_note",
|
||||||
|
"update_note",
|
||||||
|
"delete_note",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
_FOLDER_TOOL_PREFIXES: tuple[str, ...] = (
|
_FOLDER_TOOLS: frozenset[str] = frozenset(
|
||||||
"mkdir",
|
{
|
||||||
"move_file",
|
"mkdir",
|
||||||
"rename_folder",
|
"rmdir",
|
||||||
"delete_folder",
|
"rename_folder",
|
||||||
|
"delete_folder",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -220,9 +562,9 @@ async def revert_action(
|
||||||
"""
|
"""
|
||||||
tool_name = (action.tool_name or "").lower()
|
tool_name = (action.tool_name or "").lower()
|
||||||
|
|
||||||
if tool_name.startswith(_DOC_TOOL_PREFIXES):
|
if tool_name in _DOC_TOOLS:
|
||||||
outcome = await _restore_document_revision(session, action=action)
|
outcome = await _restore_document_revision(session, action=action)
|
||||||
elif tool_name.startswith(_FOLDER_TOOL_PREFIXES):
|
elif tool_name in _FOLDER_TOOLS:
|
||||||
outcome = await _restore_folder_revision(session, action=action)
|
outcome = await _restore_folder_revision(session, action=action)
|
||||||
elif action.reverse_descriptor:
|
elif action.reverse_descriptor:
|
||||||
# Connector-owned reversibles run through the normal permission
|
# Connector-owned reversibles run through the normal permission
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||||
|
from app.agents.new_chat.feature_flags import get_flags
|
||||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||||
from app.agents.new_chat.llm_config import (
|
from app.agents.new_chat.llm_config import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
|
|
@ -70,6 +71,91 @@ _background_tasks: set[asyncio.Task] = set()
|
||||||
_perf_log = get_perf_logger()
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
    """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.

    Returns a dict with three keys:

    * ``text`` — concatenated string content (empty string if the chunk
      contributes none).
    * ``reasoning`` — concatenated reasoning content (empty string if the
      chunk contributes none).
    * ``tool_call_chunks`` — flat list of LangChain ``tool_call_chunk``
      dicts surfaced from either the typed-block list or the
      ``tool_call_chunks`` attribute.

    Background
    ----------
    ``AIMessageChunk.content`` can be:

    * a ``str`` (most providers), or
    * a ``list`` whose items are bare strings and/or typed blocks
      ``{type: 'text' | 'reasoning' | 'tool_call_chunk' | 'tool_use' | ...,
      text/content/...}`` for Anthropic, Bedrock, and several reasoning
      configurations.

    Reasoning may also live under
    ``chunk.additional_kwargs['reasoning_content']`` (some providers
    surface it that way instead of as a typed block). Tool-call chunks
    may live under ``chunk.tool_call_chunks`` even when ``content`` is a
    plain string.

    Bare ``str`` items inside a list-valued ``content`` are treated as
    text and kept in order; previously they were dropped by the
    dict-only guard. Items that are neither ``str`` nor ``dict`` are
    ignored, as are blocks of an unrecognised ``type``.

    Args:
        chunk: The streamed message chunk, or ``None``. Only the
            ``content``, ``additional_kwargs`` and ``tool_call_chunks``
            attributes are read, each via ``getattr`` with a default, so
            any object (or one missing those attributes) is accepted.

    Returns:
        A new dict with the ``text``, ``reasoning`` and
        ``tool_call_chunks`` keys described above. The tool-call dicts
        are the same objects found on the chunk, not copies.
    """
    out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []}
    if chunk is None:
        return out

    content = getattr(chunk, "content", None)
    if isinstance(content, str):
        out["text"] = content
    elif isinstance(content, list):
        text_parts: list[str] = []
        reasoning_parts: list[str] = []
        for block in content:
            if isinstance(block, str):
                # List content may mix bare strings with typed blocks;
                # a bare string is plain text.
                text_parts.append(block)
                continue
            if not isinstance(block, dict):
                continue
            block_type = block.get("type")
            if block_type == "text":
                value = block.get("text") or block.get("content") or ""
                if isinstance(value, str):
                    text_parts.append(value)
            elif block_type == "reasoning":
                value = (
                    block.get("reasoning")
                    or block.get("text")
                    or block.get("content")
                    or ""
                )
                if isinstance(value, str):
                    reasoning_parts.append(value)
            elif block_type in ("tool_call_chunk", "tool_use"):
                out["tool_call_chunks"].append(block)
        out["text"] = "".join(text_parts)
        out["reasoning"] = "".join(reasoning_parts)

    # Reasoning surfaced through ``additional_kwargs`` is appended after
    # any reasoning collected from typed blocks.
    additional = getattr(chunk, "additional_kwargs", None) or {}
    if isinstance(additional, dict):
        extra_reasoning = additional.get("reasoning_content")
        if isinstance(extra_reasoning, str):
            out["reasoning"] += extra_reasoning

    # Tool-call chunks from the attribute follow those found in content;
    # non-dict entries are skipped.
    extra_tool_chunks = getattr(chunk, "tool_call_chunks", None)
    if isinstance(extra_tool_chunks, list):
        out["tool_call_chunks"].extend(
            tcc for tcc in extra_tool_chunks if isinstance(tcc, dict)
        )

    return out
|
||||||
|
|
||||||
|
|
||||||
def format_mentioned_surfsense_docs_as_context(
|
def format_mentioned_surfsense_docs_as_context(
|
||||||
documents: list[SurfsenseDocsDocument],
|
documents: list[SurfsenseDocsDocument],
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
@ -266,6 +352,7 @@ async def _stream_agent_events(
|
||||||
fallback_commit_search_space_id: int | None = None,
|
fallback_commit_search_space_id: int | None = None,
|
||||||
fallback_commit_created_by_id: str | None = None,
|
fallback_commit_created_by_id: str | None = None,
|
||||||
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||||
|
fallback_commit_thread_id: int | None = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""Shared async generator that streams and formats astream_events from the agent.
|
"""Shared async generator that streams and formats astream_events from the agent.
|
||||||
|
|
||||||
|
|
@ -298,6 +385,41 @@ async def _stream_agent_events(
|
||||||
active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool
|
active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool
|
||||||
called_update_memory: bool = False
|
called_update_memory: bool = False
|
||||||
|
|
||||||
|
# Reasoning-block streaming. We open a reasoning block on the
|
||||||
|
# first reasoning delta of a step, append deltas as they arrive, and
|
||||||
|
# close it when text starts (the model has switched to writing its
|
||||||
|
# answer) or ``on_chat_model_end`` fires for the model node. Reuses
|
||||||
|
# the same Vercel format-helpers as text-start/delta/end.
|
||||||
|
current_reasoning_id: str | None = None
|
||||||
|
|
||||||
|
# Streaming-parity v2 feature flag. When OFF we keep the legacy
|
||||||
|
# shape: str-only content, no reasoning blocks, no
|
||||||
|
# ``langchainToolCallId`` propagation. The schema migrations
|
||||||
|
# (135 / 136) ship unconditionally because they're forward-compatible.
|
||||||
|
parity_v2 = bool(get_flags().enable_stream_parity_v2)
|
||||||
|
|
||||||
|
# Best-effort attach of LangChain ``tool_call_id`` to the synthetic
|
||||||
|
# ``call_<run_id>`` card id we already emit. We accumulate
|
||||||
|
# ``tool_call_chunks`` from ``on_chat_model_stream``, key them by
|
||||||
|
# name, and pop the next unconsumed entry at ``on_tool_start``. The
|
||||||
|
# authoritative id is later filled in at ``on_tool_end`` from
|
||||||
|
# ``ToolMessage.tool_call_id``.
|
||||||
|
pending_tool_call_chunks: list[dict[str, Any]] = []
|
||||||
|
lc_tool_call_id_by_run: dict[str, str] = {}
|
||||||
|
|
||||||
|
# Per-tool-end mutable cache for the LangChain tool_call_id resolved
|
||||||
|
# at ``on_tool_end``. ``_emit_tool_output`` reads this so every
|
||||||
|
# ``format_tool_output_available`` call automatically carries the
|
||||||
|
# authoritative id without duplicating the kwarg at every call site.
|
||||||
|
current_lc_tool_call_id: dict[str, str | None] = {"value": None}
|
||||||
|
|
||||||
|
def _emit_tool_output(call_id: str, output: Any) -> str:
|
||||||
|
return streaming_service.format_tool_output_available(
|
||||||
|
call_id,
|
||||||
|
output,
|
||||||
|
langchain_tool_call_id=current_lc_tool_call_id["value"],
|
||||||
|
)
|
||||||
|
|
||||||
def next_thinking_step_id() -> str:
|
def next_thinking_step_id() -> str:
|
||||||
nonlocal thinking_step_counter
|
nonlocal thinking_step_counter
|
||||||
thinking_step_counter += 1
|
thinking_step_counter += 1
|
||||||
|
|
@ -326,22 +448,61 @@ async def _stream_agent_events(
|
||||||
if "surfsense:internal" in event.get("tags", []):
|
if "surfsense:internal" in event.get("tags", []):
|
||||||
continue # Suppress middleware-internal LLM tokens (e.g. KB search classification)
|
continue # Suppress middleware-internal LLM tokens (e.g. KB search classification)
|
||||||
chunk = event.get("data", {}).get("chunk")
|
chunk = event.get("data", {}).get("chunk")
|
||||||
if chunk and hasattr(chunk, "content"):
|
if not chunk:
|
||||||
content = chunk.content
|
continue
|
||||||
if content and isinstance(content, str):
|
parts = _extract_chunk_parts(chunk)
|
||||||
if current_text_id is None:
|
|
||||||
completion_event = complete_current_step()
|
# Accumulate any tool_call_chunks for best-effort
|
||||||
if completion_event:
|
# correlation with ``on_tool_start`` below. We don't emit
|
||||||
yield completion_event
|
# anything here; the matching is done at tool-start time.
|
||||||
if just_finished_tool:
|
if parity_v2 and parts["tool_call_chunks"]:
|
||||||
last_active_step_id = None
|
for tcc in parts["tool_call_chunks"]:
|
||||||
last_active_step_title = ""
|
pending_tool_call_chunks.append(tcc)
|
||||||
last_active_step_items = []
|
|
||||||
just_finished_tool = False
|
reasoning_delta = parts["reasoning"]
|
||||||
current_text_id = streaming_service.generate_text_id()
|
text_delta = parts["text"]
|
||||||
yield streaming_service.format_text_start(current_text_id)
|
|
||||||
yield streaming_service.format_text_delta(current_text_id, content)
|
# Reasoning streaming. Open a reasoning block on first
|
||||||
accumulated_text += content
|
# delta; append every subsequent delta until text begins.
|
||||||
|
# When text starts we close the reasoning block first so the
|
||||||
|
# frontend sees the natural hand-off. Gated behind the
|
||||||
|
# parity-v2 flag so legacy deployments keep today's shape.
|
||||||
|
if parity_v2 and reasoning_delta:
|
||||||
|
if current_text_id is not None:
|
||||||
|
yield streaming_service.format_text_end(current_text_id)
|
||||||
|
current_text_id = None
|
||||||
|
if current_reasoning_id is None:
|
||||||
|
completion_event = complete_current_step()
|
||||||
|
if completion_event:
|
||||||
|
yield completion_event
|
||||||
|
if just_finished_tool:
|
||||||
|
last_active_step_id = None
|
||||||
|
last_active_step_title = ""
|
||||||
|
last_active_step_items = []
|
||||||
|
just_finished_tool = False
|
||||||
|
current_reasoning_id = streaming_service.generate_reasoning_id()
|
||||||
|
yield streaming_service.format_reasoning_start(current_reasoning_id)
|
||||||
|
yield streaming_service.format_reasoning_delta(
|
||||||
|
current_reasoning_id, reasoning_delta
|
||||||
|
)
|
||||||
|
|
||||||
|
if text_delta:
|
||||||
|
if current_reasoning_id is not None:
|
||||||
|
yield streaming_service.format_reasoning_end(current_reasoning_id)
|
||||||
|
current_reasoning_id = None
|
||||||
|
if current_text_id is None:
|
||||||
|
completion_event = complete_current_step()
|
||||||
|
if completion_event:
|
||||||
|
yield completion_event
|
||||||
|
if just_finished_tool:
|
||||||
|
last_active_step_id = None
|
||||||
|
last_active_step_title = ""
|
||||||
|
last_active_step_items = []
|
||||||
|
just_finished_tool = False
|
||||||
|
current_text_id = streaming_service.generate_text_id()
|
||||||
|
yield streaming_service.format_text_start(current_text_id)
|
||||||
|
yield streaming_service.format_text_delta(current_text_id, text_delta)
|
||||||
|
accumulated_text += text_delta
|
||||||
|
|
||||||
elif event_type == "on_tool_start":
|
elif event_type == "on_tool_start":
|
||||||
active_tool_depth += 1
|
active_tool_depth += 1
|
||||||
|
|
@ -581,7 +742,39 @@ async def _stream_agent_events(
|
||||||
if run_id
|
if run_id
|
||||||
else streaming_service.generate_tool_call_id()
|
else streaming_service.generate_tool_call_id()
|
||||||
)
|
)
|
||||||
yield streaming_service.format_tool_input_start(tool_call_id, tool_name)
|
|
||||||
|
# Best-effort attach the LangChain ``tool_call_id``. We
|
||||||
|
# pop the first chunk in ``pending_tool_call_chunks`` whose
|
||||||
|
# name matches; if none match (the chunked args may not yet
|
||||||
|
# carry a ``name`` field, or the model skipped the chunked
|
||||||
|
# form) we leave ``langchainToolCallId`` unset for now and
|
||||||
|
# fill it in authoritatively at ``on_tool_end`` from
|
||||||
|
# ``ToolMessage.tool_call_id``.
|
||||||
|
langchain_tool_call_id: str | None = None
|
||||||
|
if parity_v2 and pending_tool_call_chunks:
|
||||||
|
matched_idx: int | None = None
|
||||||
|
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||||
|
if tcc.get("name") == tool_name and tcc.get("id"):
|
||||||
|
matched_idx = idx
|
||||||
|
break
|
||||||
|
if matched_idx is None:
|
||||||
|
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||||
|
if tcc.get("id"):
|
||||||
|
matched_idx = idx
|
||||||
|
break
|
||||||
|
if matched_idx is not None:
|
||||||
|
matched = pending_tool_call_chunks.pop(matched_idx)
|
||||||
|
candidate = matched.get("id")
|
||||||
|
if isinstance(candidate, str) and candidate:
|
||||||
|
langchain_tool_call_id = candidate
|
||||||
|
if run_id:
|
||||||
|
lc_tool_call_id_by_run[run_id] = candidate
|
||||||
|
|
||||||
|
yield streaming_service.format_tool_input_start(
|
||||||
|
tool_call_id,
|
||||||
|
tool_name,
|
||||||
|
langchain_tool_call_id=langchain_tool_call_id,
|
||||||
|
)
|
||||||
# Sanitize tool_input: strip runtime-injected non-serializable
|
# Sanitize tool_input: strip runtime-injected non-serializable
|
||||||
# values (e.g. LangChain ToolRuntime) before sending over SSE.
|
# values (e.g. LangChain ToolRuntime) before sending over SSE.
|
||||||
if isinstance(tool_input, dict):
|
if isinstance(tool_input, dict):
|
||||||
|
|
@ -598,6 +791,7 @@ async def _stream_agent_events(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
tool_name,
|
tool_name,
|
||||||
_safe_input,
|
_safe_input,
|
||||||
|
langchain_tool_call_id=langchain_tool_call_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif event_type == "on_tool_end":
|
elif event_type == "on_tool_end":
|
||||||
|
|
@ -639,6 +833,23 @@ async def _stream_agent_events(
|
||||||
)
|
)
|
||||||
completed_step_ids.add(original_step_id)
|
completed_step_ids.add(original_step_id)
|
||||||
|
|
||||||
|
# Authoritative LangChain tool_call_id from the returned
|
||||||
|
# ``ToolMessage``. Falls back to whatever we matched
|
||||||
|
# at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``)
|
||||||
|
# if the output isn't a ToolMessage. The value is stored in
|
||||||
|
# ``current_lc_tool_call_id`` so ``_emit_tool_output``
|
||||||
|
# picks it up for every output emit below. Stays None when
|
||||||
|
# parity_v2 is off so legacy emit paths are untouched.
|
||||||
|
current_lc_tool_call_id["value"] = None
|
||||||
|
if parity_v2:
|
||||||
|
authoritative = getattr(raw_output, "tool_call_id", None)
|
||||||
|
if isinstance(authoritative, str) and authoritative:
|
||||||
|
current_lc_tool_call_id["value"] = authoritative
|
||||||
|
if run_id:
|
||||||
|
lc_tool_call_id_by_run[run_id] = authoritative
|
||||||
|
elif run_id and run_id in lc_tool_call_id_by_run:
|
||||||
|
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
|
||||||
|
|
||||||
if tool_name == "read_file":
|
if tool_name == "read_file":
|
||||||
yield streaming_service.format_thinking_step(
|
yield streaming_service.format_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
|
|
@ -938,7 +1149,7 @@ async def _stream_agent_events(
|
||||||
last_active_step_items = []
|
last_active_step_items = []
|
||||||
|
|
||||||
if tool_name == "generate_podcast":
|
if tool_name == "generate_podcast":
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
tool_output
|
tool_output
|
||||||
if isinstance(tool_output, dict)
|
if isinstance(tool_output, dict)
|
||||||
|
|
@ -963,7 +1174,7 @@ async def _stream_agent_events(
|
||||||
"error",
|
"error",
|
||||||
)
|
)
|
||||||
elif tool_name == "generate_video_presentation":
|
elif tool_name == "generate_video_presentation":
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
tool_output
|
tool_output
|
||||||
if isinstance(tool_output, dict)
|
if isinstance(tool_output, dict)
|
||||||
|
|
@ -991,7 +1202,7 @@ async def _stream_agent_events(
|
||||||
"error",
|
"error",
|
||||||
)
|
)
|
||||||
elif tool_name == "generate_image":
|
elif tool_name == "generate_image":
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
tool_output
|
tool_output
|
||||||
if isinstance(tool_output, dict)
|
if isinstance(tool_output, dict)
|
||||||
|
|
@ -1018,12 +1229,12 @@ async def _stream_agent_events(
|
||||||
display_output["content_preview"] = (
|
display_output["content_preview"] = (
|
||||||
content[:500] + "..." if len(content) > 500 else content
|
content[:500] + "..." if len(content) > 500 else content
|
||||||
)
|
)
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
display_output,
|
display_output,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
{"result": tool_output},
|
{"result": tool_output},
|
||||||
)
|
)
|
||||||
|
|
@ -1051,7 +1262,7 @@ async def _stream_agent_events(
|
||||||
)
|
)
|
||||||
result_text = _tool_output_to_text(tool_output)
|
result_text = _tool_output_to_text(tool_output)
|
||||||
if _tool_output_has_error(tool_output):
|
if _tool_output_has_error(tool_output):
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
{
|
{
|
||||||
"status": "error",
|
"status": "error",
|
||||||
|
|
@ -1060,7 +1271,7 @@ async def _stream_agent_events(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
{
|
{
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
|
|
@ -1070,7 +1281,7 @@ async def _stream_agent_events(
|
||||||
)
|
)
|
||||||
elif tool_name == "generate_report":
|
elif tool_name == "generate_report":
|
||||||
# Stream the full report result so frontend can render the ReportCard
|
# Stream the full report result so frontend can render the ReportCard
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
tool_output
|
tool_output
|
||||||
if isinstance(tool_output, dict)
|
if isinstance(tool_output, dict)
|
||||||
|
|
@ -1097,7 +1308,7 @@ async def _stream_agent_events(
|
||||||
"error",
|
"error",
|
||||||
)
|
)
|
||||||
elif tool_name == "generate_resume":
|
elif tool_name == "generate_resume":
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
tool_output
|
tool_output
|
||||||
if isinstance(tool_output, dict)
|
if isinstance(tool_output, dict)
|
||||||
|
|
@ -1148,7 +1359,7 @@ async def _stream_agent_events(
|
||||||
"update_confluence_page",
|
"update_confluence_page",
|
||||||
"delete_confluence_page",
|
"delete_confluence_page",
|
||||||
):
|
):
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
tool_output
|
tool_output
|
||||||
if isinstance(tool_output, dict)
|
if isinstance(tool_output, dict)
|
||||||
|
|
@ -1176,7 +1387,7 @@ async def _stream_agent_events(
|
||||||
if fpath and fpath not in result.sandbox_files:
|
if fpath and fpath not in result.sandbox_files:
|
||||||
result.sandbox_files.append(fpath)
|
result.sandbox_files.append(fpath)
|
||||||
|
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
{
|
{
|
||||||
"exit_code": exit_code,
|
"exit_code": exit_code,
|
||||||
|
|
@ -1211,12 +1422,12 @@ async def _stream_agent_events(
|
||||||
citations[chunk_url]["snippet"] = (
|
citations[chunk_url]["snippet"] = (
|
||||||
content[:200] + "…" if len(content) > 200 else content
|
content[:200] + "…" if len(content) > 200 else content
|
||||||
)
|
)
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
{"status": "completed", "citations": citations},
|
{"status": "completed", "citations": citations},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield streaming_service.format_tool_output_available(
|
yield _emit_tool_output(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
{"status": "completed", "result_length": len(str(tool_output))},
|
{"status": "completed", "result_length": len(str(tool_output))},
|
||||||
)
|
)
|
||||||
|
|
@ -1274,6 +1485,25 @@ async def _stream_agent_events(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif event_type == "on_custom_event" and event.get("name") == "action_log":
|
||||||
|
# Surface a freshly committed AgentActionLog row so the chat
|
||||||
|
# tool card can render its Revert button immediately.
|
||||||
|
data = event.get("data", {})
|
||||||
|
if data.get("id") is not None:
|
||||||
|
yield streaming_service.format_data("action-log", data)
|
||||||
|
|
||||||
|
elif (
|
||||||
|
event_type == "on_custom_event"
|
||||||
|
and event.get("name") == "action_log_updated"
|
||||||
|
):
|
||||||
|
# Reversibility flipped in kb_persistence after the SAVEPOINT
|
||||||
|
# for a destructive op (rm/rmdir/move/edit/write) committed.
|
||||||
|
# Frontend uses this to flip the card's Revert
|
||||||
|
# button on without re-fetching the actions list.
|
||||||
|
data = event.get("data", {})
|
||||||
|
if data.get("id") is not None:
|
||||||
|
yield streaming_service.format_data("action-log-updated", data)
|
||||||
|
|
||||||
elif event_type in ("on_chain_end", "on_agent_end"):
|
elif event_type in ("on_chain_end", "on_agent_end"):
|
||||||
if current_text_id is not None:
|
if current_text_id is not None:
|
||||||
yield streaming_service.format_text_end(current_text_id)
|
yield streaming_service.format_text_end(current_text_id)
|
||||||
|
|
@ -1291,11 +1521,12 @@ async def _stream_agent_events(
|
||||||
|
|
||||||
# Safety net: if astream_events was cancelled before
|
# Safety net: if astream_events was cancelled before
|
||||||
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
||||||
# (dirty_paths / staged_dirs / pending_moves) will still be in the
|
# (dirty_paths / staged_dirs / pending_moves / pending_deletes /
|
||||||
# checkpointed state. Run the SAME shared commit helper here so the
|
# pending_dir_deletes) will still be in the checkpointed state. Run
|
||||||
# turn's writes don't get lost on client disconnect, then push the
|
# the SAME shared commit helper here so the turn's writes don't get
|
||||||
# delta back into the graph using `as_node=...` so reducers fire as if
|
# lost on client disconnect, then push the delta back into the graph
|
||||||
# the after_agent hook produced it.
|
# using `as_node=...` so reducers fire as if the after_agent hook
|
||||||
|
# produced it.
|
||||||
if (
|
if (
|
||||||
fallback_commit_filesystem_mode == FilesystemMode.CLOUD
|
fallback_commit_filesystem_mode == FilesystemMode.CLOUD
|
||||||
and fallback_commit_search_space_id is not None
|
and fallback_commit_search_space_id is not None
|
||||||
|
|
@ -1303,6 +1534,8 @@ async def _stream_agent_events(
|
||||||
(state_values.get("dirty_paths") or [])
|
(state_values.get("dirty_paths") or [])
|
||||||
or (state_values.get("staged_dirs") or [])
|
or (state_values.get("staged_dirs") or [])
|
||||||
or (state_values.get("pending_moves") or [])
|
or (state_values.get("pending_moves") or [])
|
||||||
|
or (state_values.get("pending_deletes") or [])
|
||||||
|
or (state_values.get("pending_dir_deletes") or [])
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
|
@ -1311,6 +1544,7 @@ async def _stream_agent_events(
|
||||||
search_space_id=fallback_commit_search_space_id,
|
search_space_id=fallback_commit_search_space_id,
|
||||||
created_by_id=fallback_commit_created_by_id,
|
created_by_id=fallback_commit_created_by_id,
|
||||||
filesystem_mode=fallback_commit_filesystem_mode,
|
filesystem_mode=fallback_commit_filesystem_mode,
|
||||||
|
thread_id=fallback_commit_thread_id,
|
||||||
dispatch_events=False,
|
dispatch_events=False,
|
||||||
)
|
)
|
||||||
if delta:
|
if delta:
|
||||||
|
|
@ -1726,6 +1960,17 @@ async def stream_new_chat(
|
||||||
yield streaming_service.format_message_start()
|
yield streaming_service.format_message_start()
|
||||||
yield streaming_service.format_start_step()
|
yield streaming_service.format_start_step()
|
||||||
|
|
||||||
|
# Surface the per-turn correlation id at the very start of the
|
||||||
|
# stream so the frontend can stamp it onto the in-flight
|
||||||
|
# assistant message and replay it via ``appendMessage``
|
||||||
|
# for durable storage. Tool/action-log events DO carry it later,
|
||||||
|
# but pure-text turns never produce action-log events; this
|
||||||
|
# event guarantees the frontend learns the turn id regardless.
|
||||||
|
yield streaming_service.format_data(
|
||||||
|
"turn-info",
|
||||||
|
{"chat_turn_id": stream_result.turn_id},
|
||||||
|
)
|
||||||
|
|
||||||
# Initial thinking step - analyzing the request
|
# Initial thinking step - analyzing the request
|
||||||
if mentioned_surfsense_docs:
|
if mentioned_surfsense_docs:
|
||||||
initial_title = "Analyzing referenced content"
|
initial_title = "Analyzing referenced content"
|
||||||
|
|
@ -1876,6 +2121,7 @@ async def stream_new_chat(
|
||||||
if filesystem_selection
|
if filesystem_selection
|
||||||
else FilesystemMode.CLOUD
|
else FilesystemMode.CLOUD
|
||||||
),
|
),
|
||||||
|
fallback_commit_thread_id=chat_id,
|
||||||
):
|
):
|
||||||
if not _first_event_logged:
|
if not _first_event_logged:
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
@ -2308,6 +2554,13 @@ async def stream_resume_chat(
|
||||||
|
|
||||||
yield streaming_service.format_message_start()
|
yield streaming_service.format_message_start()
|
||||||
yield streaming_service.format_start_step()
|
yield streaming_service.format_start_step()
|
||||||
|
# Same rationale as ``stream_new_chat``: emit the turn id so
|
||||||
|
# resumed streams can be persisted with their correlation id
|
||||||
|
# intact.
|
||||||
|
yield streaming_service.format_data(
|
||||||
|
"turn-info",
|
||||||
|
{"chat_turn_id": stream_result.turn_id},
|
||||||
|
)
|
||||||
|
|
||||||
_t_stream_start = time.perf_counter()
|
_t_stream_start = time.perf_counter()
|
||||||
_first_event_logged = False
|
_first_event_logged = False
|
||||||
|
|
@ -2325,6 +2578,7 @@ async def stream_resume_chat(
|
||||||
if filesystem_selection
|
if filesystem_selection
|
||||||
else FilesystemMode.CLOUD
|
else FilesystemMode.CLOUD
|
||||||
),
|
),
|
||||||
|
fallback_commit_thread_id=chat_id,
|
||||||
):
|
):
|
||||||
if not _first_event_logged:
|
if not _first_event_logged:
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,17 @@ from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
|
||||||
from app.agents.new_chat.tools.registry import ToolDefinition
|
from app.agents.new_chat.tools.registry import ToolDefinition
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeRuntime:
|
||||||
|
"""Minimal stand-in for ``ToolRuntime`` used in unit tests.
|
||||||
|
|
||||||
|
``ActionLogMiddleware`` reads ``runtime.config['configurable']['turn_id']``
|
||||||
|
to populate the new ``chat_turn_id`` column (see migration 135).
|
||||||
|
"""
|
||||||
|
|
||||||
|
config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _FakeRequest:
|
class _FakeRequest:
|
||||||
"""Minimal stand-in for ToolCallRequest used in unit tests."""
|
"""Minimal stand-in for ToolCallRequest used in unit tests."""
|
||||||
|
|
@ -120,6 +131,9 @@ class TestActionLogMiddlewarePersistence:
|
||||||
"args": {"color": "red", "size": 3},
|
"args": {"color": "red", "size": 3},
|
||||||
"id": "tc-abc",
|
"id": "tc-abc",
|
||||||
},
|
},
|
||||||
|
runtime=_FakeRuntime(
|
||||||
|
config={"configurable": {"turn_id": "42:1700000000000"}}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
|
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
|
||||||
handler = AsyncMock(return_value=result_msg)
|
handler = AsyncMock(return_value=result_msg)
|
||||||
|
|
@ -142,6 +156,32 @@ class TestActionLogMiddlewarePersistence:
|
||||||
assert row.error is None
|
assert row.error is None
|
||||||
assert row.reverse_descriptor is None
|
assert row.reverse_descriptor is None
|
||||||
assert row.reversible is False
|
assert row.reversible is False
|
||||||
|
# Migration 135: ``turn_id`` is the deprecated alias of ``tool_call_id``;
|
||||||
|
# ``chat_turn_id`` comes from ``runtime.config['configurable']['turn_id']``.
|
||||||
|
assert row.tool_call_id == "tc-abc"
|
||||||
|
assert row.turn_id == "tc-abc"
|
||||||
|
assert row.chat_turn_id == "42:1700000000000"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_chat_turn_id_none_when_runtime_missing(
    self, patch_get_flags, fake_session_factory
) -> None:
    """A request without a runtime must persist ``chat_turn_id`` as NULL.

    The tool-call id is still expected on the row; only the turn
    correlation id, which comes from ``runtime.config``, should be ``None``.
    """
    captured, factory = fake_session_factory
    middleware = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
    tool_call = {"name": "make_widget", "args": {}, "id": "tc-1"}
    request = _FakeRequest(tool_call=tool_call, runtime=None)
    tool_reply = ToolMessage(content="ok", tool_call_id="tc-1")
    handler = AsyncMock(return_value=tool_reply)

    def _open_session():
        # Defer to the fixture's factory at call time.
        return factory()

    flags_patch = patch_get_flags(_enabled_flags())
    session_patch = patch("app.db.shielded_async_session", side_effect=_open_session)
    with flags_patch, session_patch:
        await middleware.awrap_tool_call(request, handler)

    persisted = captured["rows"][0]
    assert persisted.tool_call_id == "tc-1"
    assert persisted.chat_turn_id is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_writes_row_on_failure_and_reraises(
|
async def test_writes_row_on_failure_and_reraises(
|
||||||
|
|
@ -293,6 +333,76 @@ class TestReverseDescriptor:
|
||||||
assert row.reversible is False
|
assert row.reversible is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestActionLogDispatch:
    """Checks when the middleware emits its ``action_log`` custom event."""

    @pytest.mark.asyncio
    async def test_dispatches_action_log_event_on_success(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        """A successful tool call dispatches exactly one ``action_log`` event."""
        _captured, factory = fake_session_factory
        middleware = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
        runtime = _FakeRuntime(
            config={"configurable": {"turn_id": "42:1700000000000"}}
        )
        tool_call = {
            "name": "make_widget",
            "args": {"color": "red"},
            "id": "tc-evt",
        }
        request = _FakeRequest(tool_call=tool_call, runtime=runtime)
        tool_reply = ToolMessage(content="ok", tool_call_id="tc-evt", id="msg-42")
        handler = AsyncMock(return_value=tool_reply)
        dispatch_mock = AsyncMock()

        def _open_session():
            # Defer to the fixture's factory at call time.
            return factory()

        flags_patch = patch_get_flags(_enabled_flags())
        session_patch = patch(
            "app.db.shielded_async_session", side_effect=_open_session
        )
        dispatch_patch = patch(
            "app.agents.new_chat.middleware.action_log.adispatch_custom_event",
            dispatch_mock,
        )
        with flags_patch, session_patch, dispatch_patch:
            await middleware.awrap_tool_call(request, handler)

        dispatch_mock.assert_awaited_once()
        awaited = dispatch_mock.await_args
        assert awaited is not None
        event_name = awaited.args[0]
        payload = awaited.args[1]
        assert event_name == "action_log"
        assert payload["lc_tool_call_id"] == "tc-evt"
        assert payload["chat_turn_id"] == "42:1700000000000"
        assert payload["tool_name"] == "make_widget"
        assert payload["reversible"] is False
        assert payload["reverse_descriptor_present"] is False
        assert payload["error"] is False

    @pytest.mark.asyncio
    async def test_no_dispatch_when_persistence_fails(self, patch_get_flags) -> None:
        """If commit fails the dispatch is suppressed (no row to surface)."""
        middleware = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
        )
        tool_reply = ToolMessage(content="ok", tool_call_id="tc1")
        handler = AsyncMock(return_value=tool_reply)
        dispatch_mock = AsyncMock()

        def _exploding_session():
            # Simulate the session factory itself failing.
            raise RuntimeError("DB is down")

        flags_patch = patch_get_flags(_enabled_flags())
        session_patch = patch(
            "app.db.shielded_async_session", side_effect=_exploding_session
        )
        dispatch_patch = patch(
            "app.agents.new_chat.middleware.action_log.adispatch_custom_event",
            dispatch_mock,
        )
        with flags_patch, session_patch, dispatch_patch:
            await middleware.awrap_tool_call(request, handler)

        dispatch_mock.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
class TestArgsTruncation:
|
class TestArgsTruncation:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_huge_args_payload_is_truncated(
|
async def test_huge_args_payload_is_truncated(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,122 @@
|
||||||
|
"""Tests for the desktop-mode safety ruleset.
|
||||||
|
|
||||||
|
In desktop mode the agent operates against the user's real disk with no
|
||||||
|
revision history, so destructive filesystem operations must require
|
||||||
|
explicit approval. These tests pin the set of tools that get the ``ask``
|
||||||
|
gate so it cannot silently regress.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
||||||
|
from app.agents.new_chat.permissions import (
|
||||||
|
Rule,
|
||||||
|
Ruleset,
|
||||||
|
aggregate_action,
|
||||||
|
evaluate_many,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# Local copy of the ruleset that ``chat_deepagent._build_compiled_agent_blocking``
# builds when ``filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER``. The
# graph-build helper is hard to instantiate in unit tests, so the rule
# contract is pinned here with a focused regression test instead.
DESKTOP_SAFETY_RULESET = Ruleset(
    rules=[
        Rule(permission=tool, pattern="*", action="ask")
        for tool in ("rm", "rmdir", "move_file", "edit_file", "write_file")
    ],
    origin="desktop_safety",
)

# Blanket "allow everything" baseline that the safety ruleset is layered on.
SURFSENSE_DEFAULTS = Ruleset(
    origin="surfsense_defaults",
    rules=[Rule(permission="*", pattern="*", action="allow")],
)
|
||||||
|
|
||||||
|
|
||||||
|
def _action_for(tool_name: str, *rulesets: Ruleset) -> str:
    """Return the aggregate action for ``tool_name`` under ``rulesets``.

    The tool name is used both as the permission and as the single pattern
    passed to ``evaluate_many``; the resulting rules are folded with
    ``aggregate_action``.
    """
    return aggregate_action(evaluate_many(tool_name, [tool_name], *rulesets))
|
||||||
|
|
||||||
|
|
||||||
|
class TestDesktopSafetyRulesGateDestructiveOps:
    """Pins which tools are gated behind ``ask`` in desktop mode."""

    @pytest.mark.parametrize(
        "tool_name",
        ["rm", "rmdir", "move_file", "edit_file", "write_file"],
    )
    def test_destructive_op_resolves_to_ask(self, tool_name: str) -> None:
        # The defaults allow everything; the desktop ruleset is layered
        # after them and therefore has to win (last-match-wins).
        layered = (SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        action = _action_for(tool_name, *layered)
        assert action == "ask", (
            f"{tool_name} must require approval in desktop mode "
            f"(no revert path on real disk); got {action!r}"
        )

    @pytest.mark.parametrize(
        "tool_name",
        ["read_file", "ls", "list_tree", "grep", "glob", "cd", "pwd", "mkdir"],
    )
    def test_safe_ops_remain_allowed(self, tool_name: str) -> None:
        # Read-only and trivially-reversible tools stay ungated; gating
        # them would raise an interrupt on every navigation step.
        layered = (SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        action = _action_for(tool_name, *layered)
        assert action == "allow", (
            f"{tool_name} should not be gated in desktop mode; got {action!r}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TestDesktopSafetyOverridesAllowDefault:
    """Guards the layering order of the two rulesets."""

    def test_layer_order_last_match_wins(self) -> None:
        # Swapping the rulesets in ``_build_compiled_agent_blocking`` would
        # let the broad allow default win and make the safety net inert.
        # Both orders are asserted so such a swap is caught here.
        wrong_order = _action_for("rm", DESKTOP_SAFETY_RULESET, SURFSENSE_DEFAULTS)
        # Safety first, defaults last: the blanket allow is the last match.
        assert wrong_order == "allow"

        right_order = _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        # Defaults first, safety last: ``ask`` is the last match.
        assert right_order == "ask"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPermissionMiddlewareIntegration:
    """End-to-end check of the ``ask`` path through ``PermissionMiddleware``."""

    def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None:
        from langchain_core.messages import AIMessage

        from app.agents.new_chat.errors import RejectedError

        def _always_reject(**kw):
            # Stand-in for the real interrupt: answer "reject" immediately.
            return {"decision_type": "reject"}

        middleware = PermissionMiddleware(
            rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET]
        )
        # Replacing the interrupt hook lets the test observe that the ask
        # path was taken without spinning up the LangGraph runtime.
        middleware._raise_interrupt = _always_reject  # type: ignore[assignment]

        rm_call = {
            "name": "rm",
            "args": {"path": "/Users/me/Documents/important.docx"},
            "id": "tc-rm",
        }
        state = {"messages": [AIMessage(content="", tool_calls=[rm_call])]}

        class _StubRuntime:
            config: dict = {"configurable": {"thread_id": "test"}}

        with pytest.raises(RejectedError):
            middleware.after_model(state, _StubRuntime())
|
||||||
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""Tests for the default auto-approval list in ``hitl.request_approval``.
|
||||||
|
|
||||||
|
These pin the policy that low-stakes connector creation tools (drafts,
|
||||||
|
new-file creates) skip the HITL interrupt by default. Without this set,
|
||||||
|
every "draft my newsletter" turn used to fire ~3 interrupts before any
|
||||||
|
useful work happened.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.hitl import (
|
||||||
|
DEFAULT_AUTO_APPROVED_TOOLS,
|
||||||
|
HITLResult,
|
||||||
|
request_approval,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultAutoApprovedToolsList:
    """Pins the exact contents and type of ``DEFAULT_AUTO_APPROVED_TOOLS``."""

    def test_set_contains_expected_creation_tools(self) -> None:
        # Any change to the policy list must be mirrored here, so the
        # contract stays explicit. Keep in sync with
        # ``hitl.DEFAULT_AUTO_APPROVED_TOOLS``.
        gmail_tools = {"create_gmail_draft", "update_gmail_draft"}
        page_tools = {"create_notion_page", "create_confluence_page"}
        file_tools = {
            "create_google_drive_file",
            "create_dropbox_file",
            "create_onedrive_file",
        }
        expected = gmail_tools | page_tools | file_tools
        assert DEFAULT_AUTO_APPROVED_TOOLS == expected

    def test_set_is_immutable(self) -> None:
        # A frozenset cannot be mutated at runtime, which would otherwise
        # silently widen the auto-approval surface.
        assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset)

    def test_send_tools_are_not_auto_approved(self) -> None:
        # Send, delete and calendar tools must always prompt.
        must_prompt = (
            "send_gmail_email",
            "send_discord_message",
            "send_teams_message",
            "delete_notion_page",
            "create_calendar_event",
            "delete_calendar_event",
        )
        for tool_name in must_prompt:
            assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, (
                f"{tool_name} must remain HITL-gated"
            )
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestApprovalAutoBypass:
    """Exercises the bypass order inside ``request_approval``."""

    def test_auto_approved_tool_skips_interrupt(self) -> None:
        # No interrupt is mocked and no runnable context exists, so a call
        # into ``langgraph.types.interrupt`` would raise. A clean
        # HITLResult therefore proves the bypass fired.
        draft_params = {"to": "alice@example.com", "subject": "hi", "body": "hey"}
        result = request_approval(
            action_type="gmail_draft_creation",
            tool_name="create_gmail_draft",
            params=dict(draft_params),
        )
        assert isinstance(result, HITLResult)
        assert result.rejected is False
        assert result.decision_type == "auto_approved"
        # Params come back exactly as supplied (no user edits possible).
        assert result.params == draft_params

    def test_non_listed_tool_still_attempts_interrupt(self) -> None:
        # A tool outside the default list has to reach the interrupt call.
        # With no runnable context that call raises RuntimeError, which is
        # the signal that the bypass did NOT fire.
        send_params = {"to": "alice@example.com", "subject": "hi", "body": "hey"}
        with pytest.raises(RuntimeError, match="runnable context"):
            request_approval(
                action_type="gmail_email_send",
                tool_name="send_gmail_email",
                params=send_params,
            )

    def test_user_trusted_tools_still_take_precedence(self) -> None:
        # ``trusted_tools`` (per-connector "always allow" from MCP/UI) is
        # checked before the default list and must keep working for tools
        # that are not on the default list.
        result = request_approval(
            action_type="mcp_tool_call",
            tool_name="my_custom_mcp_tool",
            params={"x": 1},
            trusted_tools=["my_custom_mcp_tool"],
        )
        assert result.decision_type == "trusted"
        assert result.rejected is False

    def test_auto_approved_overrides_no_trusted_tools(self) -> None:
        # An empty ``trusted_tools`` list must not block the default-list
        # bypass; this pins the check order in ``request_approval``.
        result = request_approval(
            action_type="notion_page_creation",
            tool_name="create_notion_page",
            params={"title": "Plan"},
            trusted_tools=[],
        )
        assert result.decision_type == "auto_approved"
|
||||||
|
|
@ -0,0 +1,333 @@
|
||||||
|
"""Cloud-mode behavior tests for the new ``rm`` and ``rmdir`` filesystem tools.
|
||||||
|
|
||||||
|
The tools build ``Command(update=...)`` payloads that the persistence
|
||||||
|
middleware applies at end of turn. These tests stub out the backend and
|
||||||
|
runtime to assert the staging payload shape:
|
||||||
|
|
||||||
|
* ``rm`` queues into ``pending_deletes`` and tombstones state files.
|
||||||
|
* ``rm`` rejects directories, ``/documents``, root, and the anonymous doc.
|
||||||
|
* ``rmdir`` queues into ``pending_dir_deletes`` and rejects non-empty dirs.
|
||||||
|
* ``rmdir`` un-stages a same-turn ``mkdir`` rather than queuing a delete.
|
||||||
|
* ``rmdir`` refuses to drop the cwd or any of its ancestors.
|
||||||
|
* ``KBPostgresBackend`` view-helpers honor staged deletes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||||
|
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD):
    """Create a ``SurfSenseFilesystemMiddleware`` without running ``__init__``.

    Only the two private attributes set below are populated.
    """
    instance = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
    instance._filesystem_mode = mode
    instance._custom_tool_descriptions = {}
    return instance
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc"):
|
||||||
|
state = state or {}
|
||||||
|
state.setdefault("cwd", "/documents")
|
||||||
|
return SimpleNamespace(state=state, tool_call_id=tool_call_id)
|
||||||
|
|
||||||
|
|
||||||
|
class _KBBackendStub(KBPostgresBackend):
    """Test-constructible subclass of :class:`KBPostgresBackend`.

    The real ``__init__`` (which expects a runtime + DB session) is never
    called; only the two methods the rm/rmdir tools touch are installed,
    as mocks. Subclassing keeps ``isinstance(backend, KBPostgresBackend)``
    checks inside the tools passing, which is what separates them from the
    desktop code path.
    """

    def __init__(self, *, children=None, file_data=None) -> None:
        # ``_load_file_data`` resolves to ``(file_data, 17)`` when file data
        # is supplied, otherwise to ``None``.
        loaded = None if file_data is None else (file_data, 17)
        listing = children or []
        self.als_info = AsyncMock(return_value=listing)
        self._load_file_data = AsyncMock(return_value=loaded)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_backend_stub(*, children=None, file_data=None) -> KBPostgresBackend:
    """Return a ``_KBBackendStub`` preloaded with the given listing/file data."""
    stub = _KBBackendStub(children=children, file_data=file_data)
    return stub
|
||||||
|
|
||||||
|
|
||||||
|
def _bind_backend(middleware, backend):
|
||||||
|
"""Inject a backend resolver onto the middleware test instance."""
|
||||||
|
middleware._get_backend = lambda runtime: backend
|
||||||
|
return backend
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# rm
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRmStaging:
    """Staging and rejection behaviour of the cloud-mode ``rm`` tool."""

    @pytest.mark.asyncio
    async def test_stages_delete_and_tombstones_state(self):
        middleware = _make_middleware()
        stub = _make_backend_stub(children=[], file_data={"content": ["x"]})
        _bind_backend(middleware, stub)
        initial_state = {
            "cwd": "/documents",
            "files": {"/documents/notes.md": {"content": ["hello"]}},
            "doc_id_by_path": {"/documents/notes.md": 17},
        }
        runtime = _runtime(initial_state, tool_call_id="tc-1")

        rm_tool = middleware._create_rm_tool()
        result = await rm_tool.coroutine("/documents/notes.md", runtime=runtime)

        assert hasattr(result, "update"), f"expected Command, got {result!r}"
        staged = result.update
        assert staged["pending_deletes"] == [
            {"path": "/documents/notes.md", "tool_call_id": "tc-1"}
        ]
        assert staged["files"] == {"/documents/notes.md": None}
        assert staged["doc_id_by_path"] == {"/documents/notes.md": None}

    @pytest.mark.asyncio
    async def test_rejects_documents_root(self):
        middleware = _make_middleware()
        rm_tool = middleware._create_rm_tool()
        outcome = await rm_tool.coroutine("/documents", runtime=_runtime())
        assert isinstance(outcome, str)
        assert "refusing to rm" in outcome

    @pytest.mark.asyncio
    async def test_rejects_root(self):
        middleware = _make_middleware()
        rm_tool = middleware._create_rm_tool()
        outcome = await rm_tool.coroutine("/", runtime=_runtime())
        assert isinstance(outcome, str)
        assert "refusing to rm" in outcome

    @pytest.mark.asyncio
    async def test_rejects_directory_via_staged_dirs(self):
        middleware = _make_middleware()
        runtime = _runtime({"staged_dirs": ["/documents/team-x"]})
        rm_tool = middleware._create_rm_tool()
        outcome = await rm_tool.coroutine("/documents/team-x", runtime=runtime)
        assert isinstance(outcome, str)
        assert "directory" in outcome.lower()
        assert "rmdir" in outcome

    @pytest.mark.asyncio
    async def test_rejects_directory_via_listing(self):
        middleware = _make_middleware()
        listing = [{"path": "/documents/foo/x.md", "is_dir": False}]
        _bind_backend(middleware, _make_backend_stub(children=listing))
        rm_tool = middleware._create_rm_tool()
        outcome = await rm_tool.coroutine("/documents/foo", runtime=_runtime())
        assert isinstance(outcome, str)
        assert "directory" in outcome.lower()

    @pytest.mark.asyncio
    async def test_rejects_anonymous_doc(self):
        middleware = _make_middleware()
        anon_doc = {
            "path": "/documents/uploaded.xml",
            "title": "uploaded",
            "content": "",
            "chunks": [],
        }
        runtime = _runtime({"kb_anon_doc": anon_doc})
        rm_tool = middleware._create_rm_tool()
        outcome = await rm_tool.coroutine("/documents/uploaded.xml", runtime=runtime)
        assert isinstance(outcome, str)
        assert "read-only" in outcome

    @pytest.mark.asyncio
    async def test_drops_path_from_dirty_paths(self):
        middleware = _make_middleware()
        stub = _make_backend_stub(children=[], file_data={"content": ["x"]})
        _bind_backend(middleware, stub)
        initial_state = {
            "files": {"/documents/notes.md": {"content": ["x"]}},
            "doc_id_by_path": {"/documents/notes.md": 17},
            "dirty_paths": ["/documents/notes.md"],
        }
        runtime = _runtime(initial_state)
        rm_tool = middleware._create_rm_tool()
        result = await rm_tool.coroutine("/documents/notes.md", runtime=runtime)
        staged = result.update
        # Element 0 is the _CLEAR sentinel; everything after it must NOT
        # include the path that was just removed.
        dirty = staged.get("dirty_paths") or []
        assert "/documents/notes.md" not in dirty[1:]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# rmdir
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRmdirStaging:
    """Staging and rejection behaviour of the cloud-mode ``rmdir`` tool."""

    @pytest.mark.asyncio
    async def test_stages_dir_delete_when_empty_and_db_backed(self):
        middleware = _make_middleware()
        backend = _bind_backend(middleware, _make_backend_stub(children=[]))
        # ``_load_file_data`` yields None (a folder, not a file). The two
        # successive listings are: the folder's own (empty) children, then
        # the parent listing that shows the folder exists.
        backend._load_file_data = AsyncMock(return_value=None)
        own_children = []
        parent_listing = [{"path": "/documents/proj", "is_dir": True}]
        backend.als_info = AsyncMock(side_effect=[own_children, parent_listing])
        runtime = _runtime({"cwd": "/documents"}, tool_call_id="tc-rd")

        rmdir_tool = middleware._create_rmdir_tool()
        result = await rmdir_tool.coroutine("/documents/proj", runtime=runtime)

        assert hasattr(result, "update")
        staged = result.update
        assert staged["pending_dir_deletes"] == [
            {"path": "/documents/proj", "tool_call_id": "tc-rd"}
        ]

    @pytest.mark.asyncio
    async def test_rejects_non_empty(self):
        middleware = _make_middleware()
        listing = [{"path": "/documents/proj/x.md", "is_dir": False}]
        _bind_backend(middleware, _make_backend_stub(children=listing))
        rmdir_tool = middleware._create_rmdir_tool()
        outcome = await rmdir_tool.coroutine("/documents/proj", runtime=_runtime())
        assert isinstance(outcome, str)
        assert "not empty" in outcome

    @pytest.mark.asyncio
    async def test_unstages_same_turn_mkdir(self):
        middleware = _make_middleware()
        _bind_backend(middleware, _make_backend_stub(children=[]))
        initial_state = {
            "cwd": "/documents",
            "staged_dirs": ["/documents/scratch"],
        }
        runtime = _runtime(initial_state, tool_call_id="tc-rd")
        rmdir_tool = middleware._create_rmdir_tool()
        result = await rmdir_tool.coroutine("/documents/scratch", runtime=runtime)

        assert hasattr(result, "update")
        staged = result.update
        assert "pending_dir_deletes" not in staged
        # Expected shape: _CLEAR sentinel first, then the surviving entries
        # (none in this scenario).
        staged_after = staged["staged_dirs"]
        assert staged_after[0] == "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00"
        assert "/documents/scratch" not in staged_after[1:]

    @pytest.mark.asyncio
    async def test_rejects_root(self):
        middleware = _make_middleware()
        runtime = _runtime()
        rmdir_tool = middleware._create_rmdir_tool()
        for victim in ("/", "/documents"):
            outcome = await rmdir_tool.coroutine(victim, runtime=runtime)
            assert isinstance(outcome, str)
            assert "refusing to rmdir" in outcome

    @pytest.mark.asyncio
    async def test_rejects_cwd(self):
        middleware = _make_middleware()
        runtime = _runtime({"cwd": "/documents/proj"})
        rmdir_tool = middleware._create_rmdir_tool()
        outcome = await rmdir_tool.coroutine("/documents/proj", runtime=runtime)
        assert isinstance(outcome, str)
        assert "cwd" in outcome.lower()

    @pytest.mark.asyncio
    async def test_rejects_ancestor_of_cwd(self):
        middleware = _make_middleware()
        runtime = _runtime({"cwd": "/documents/proj/sub"})
        rmdir_tool = middleware._create_rmdir_tool()
        outcome = await rmdir_tool.coroutine("/documents/proj", runtime=runtime)
        assert isinstance(outcome, str)
        assert "cwd" in outcome.lower()

    @pytest.mark.asyncio
    async def test_rejects_files(self):
        middleware = _make_middleware()
        stub = _make_backend_stub(children=[], file_data={"content": ["x"]})
        _bind_backend(middleware, stub)
        rmdir_tool = middleware._create_rmdir_tool()
        outcome = await rmdir_tool.coroutine("/documents/notes.md", runtime=_runtime())
        assert isinstance(outcome, str)
        assert "is a file" in outcome
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KBPostgresBackend view filter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKBPostgresBackendDeleteFilter:
    """als_info / glob / grep should suppress paths queued for delete."""

    def _make_backend(self, state: dict[str, Any]) -> KBPostgresBackend:
        """Build a real backend around a namespace that only carries ``state``."""
        return KBPostgresBackend(
            search_space_id=1,
            runtime=SimpleNamespace(state=state),
        )

    def test_pending_filesystem_view_returns_deleted_paths(self):
        queued_state = {
            "pending_deletes": [
                {"path": "/documents/x.md", "tool_call_id": "t1"},
            ],
            "pending_dir_deletes": [
                {"path": "/documents/d1", "tool_call_id": "t2"},
            ],
        }
        backend = self._make_backend(queued_state)
        removed, alias, deleted_dirs = backend._pending_filesystem_view({})
        assert "/documents/x.md" in removed
        assert "/documents/d1" in deleted_dirs
        assert alias == {}

    def test_dir_suppressed_covers_descendants(self):
        backend = self._make_backend({})
        deleted_dirs = {"/documents/d"}
        # The directory itself and everything beneath it is suppressed.
        for hidden in (
            "/documents/d",
            "/documents/d/x.md",
            "/documents/d/sub/y.md",
        ):
            assert backend._is_dir_suppressed(hidden, deleted_dirs)
        # A sibling outside the deleted directory stays visible.
        assert not backend._is_dir_suppressed("/documents/other.md", deleted_dirs)
|
||||||
|
|
@ -98,10 +98,54 @@ class TestInitialFilesystemState:
|
||||||
state = _initial_filesystem_state()
|
state = _initial_filesystem_state()
|
||||||
assert state["cwd"] == "/documents"
|
assert state["cwd"] == "/documents"
|
||||||
assert state["staged_dirs"] == []
|
assert state["staged_dirs"] == []
|
||||||
|
assert state["staged_dir_tool_calls"] == {}
|
||||||
assert state["pending_moves"] == []
|
assert state["pending_moves"] == []
|
||||||
|
assert state["pending_deletes"] == []
|
||||||
|
assert state["pending_dir_deletes"] == []
|
||||||
assert state["doc_id_by_path"] == {}
|
assert state["doc_id_by_path"] == {}
|
||||||
assert state["dirty_paths"] == []
|
assert state["dirty_paths"] == []
|
||||||
|
assert state["dirty_path_tool_calls"] == {}
|
||||||
assert state["kb_priority"] == []
|
assert state["kb_priority"] == []
|
||||||
assert state["kb_matched_chunk_ids"] == {}
|
assert state["kb_matched_chunk_ids"] == {}
|
||||||
assert state["kb_anon_doc"] is None
|
assert state["kb_anon_doc"] is None
|
||||||
assert state["tree_version"] == 0
|
assert state["tree_version"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiEditSamePathCoalescing:
    """Repeated edits to one path in a turn collapse into ONE binding record.

    The persistence body reads ``dirty_path_tool_calls[path]`` to find the
    tool_call_id that produced the current state on disk. ``dirty_paths``
    is deduplicated by :func:`_add_unique_reducer`, so a second edit adds
    no new path entry, and ``_dict_merge_with_tombstones_reducer`` lets the
    right-hand side overwrite, so the LATEST tool_call_id wins. That is the
    behaviour snapshotting wants: revert restores the pre-mutation state,
    and back-to-back edits within one turn form a single reversible op
    (the user sees ONE Revert button per turn-per-path, not N).
    """

    def test_dirty_paths_dedupes_repeated_writes(self):
        # ``dirty_paths`` uses ``_add_unique_reducer``: writing the same
        # path twice must still yield exactly one entry.
        after_first_write = _add_unique_reducer([], ["/documents/a.md"])
        after_second_write = _add_unique_reducer(
            after_first_write, ["/documents/a.md"]
        )
        assert after_second_write == ["/documents/a.md"]

    def test_dirty_path_tool_calls_keeps_latest_tool_call_id(self):
        # Apply two writes to the same path in order; the binding left
        # behind must be the one from the most recent write.
        bindings: dict = {}
        for tool_call_id in ("tcid-1", "tcid-2"):
            bindings = _dict_merge_with_tombstones_reducer(
                bindings, {"/documents/a.md": tool_call_id}
            )
        assert bindings == {"/documents/a.md": "tcid-2"}

    def test_rm_tombstones_dirty_path_tool_call(self):
        # ``rm`` merges ``{path: None}`` so that a stale binding cannot
        # leak past the delete.
        existing = {"/documents/a.md": "tcid-1"}
        tombstone = {"/documents/a.md": None}
        assert _dict_merge_with_tombstones_reducer(existing, tombstone) == {}
|
||||||
|
|
|
||||||
0
surfsense_backend/tests/unit/db/__init__.py
Normal file
0
surfsense_backend/tests/unit/db/__init__.py
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
"""Smoke test for the ``134_relax_revision_fks`` Alembic migration.
|
||||||
|
|
||||||
|
A full apply/rollback test would require a live Postgres; here we verify
|
||||||
|
the migration module's static contract:
|
||||||
|
|
||||||
|
* The chain wires it as a successor of ``133_drop_documents_content_hash_unique``.
|
||||||
|
* ``upgrade()`` declares two FK creations with ``ondelete='SET NULL'``
|
||||||
|
(one for ``document_revisions.document_id``, one for
|
||||||
|
``folder_revisions.folder_id``).
|
||||||
|
* ``downgrade()`` re-establishes ``ondelete='CASCADE'`` after draining
|
||||||
|
orphaned revisions.
|
||||||
|
|
||||||
|
If any of these invariants regress the snapshot/revert pipeline silently
|
||||||
|
loses the ability to undo ``rm`` / ``rmdir`` on environments that ran the
|
||||||
|
migration "down" or never ran it at all.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import inspect
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# Location of the migration under test, resolved relative to this test file
# (three directory levels above the file's own directory).
_MIGRATION_PATH = Path(__file__).resolve().parents[3].joinpath(
    "alembic",
    "versions",
    "134_relax_revision_fks.py",
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_migration():
    """Import the migration module directly from its file and return it.

    Loading by path means no package import of the versions directory is
    needed.
    """
    migration_spec = importlib.util.spec_from_file_location(
        "_migration_134", _MIGRATION_PATH
    )
    assert migration_spec and migration_spec.loader, "could not load migration spec"
    migration = importlib.util.module_from_spec(migration_spec)
    migration_spec.loader.exec_module(migration)
    return migration
|
||||||
|
|
||||||
|
|
||||||
|
def test_migration_chain_revision_ids() -> None:
    """The migration identifies itself as ``134`` and descends from ``133``."""
    migration = _load_migration()
    # Revision strings follow the in-tree short numeric convention
    # (``133`` -> ``134``). The ``134_<slug>.py`` filename is documentation
    # only, not the canonical revision string.
    own_revision = getattr(migration, "revision", None)
    parent_revision = getattr(migration, "down_revision", None)
    assert own_revision == "134"
    assert parent_revision == "133"
|
||||||
|
|
||||||
|
|
||||||
|
def test_migration_exposes_upgrade_and_downgrade() -> None:
    """Both Alembic entry points must exist on the module and be callable."""
    migration = _load_migration()
    upgrade_fn = getattr(migration, "upgrade", None)
    downgrade_fn = getattr(migration, "downgrade", None)
    assert callable(upgrade_fn), "upgrade() is required"
    assert callable(downgrade_fn), "downgrade() is required"
|
||||||
|
|
||||||
|
|
||||||
|
def test_upgrade_creates_set_null_fks_for_both_revision_tables() -> None:
    """``upgrade()`` source must mention both tables, SET NULL and nullable."""
    upgrade_source = inspect.getsource(_load_migration().upgrade)

    for table_name in ("document_revisions", "folder_revisions"):
        assert table_name in upgrade_source

    # Both new FKs MUST be ON DELETE SET NULL: the whole point of the
    # migration is that snapshots outlive their parent row.
    set_null_count = upgrade_source.count('ondelete="SET NULL"')
    assert set_null_count >= 2

    # The ``document_id`` / ``folder_id`` columns become nullable as well.
    assert "nullable=True" in upgrade_source
|
||||||
|
|
||||||
|
|
||||||
|
def test_downgrade_drains_orphans_then_restores_cascade() -> None:
    """``downgrade()`` source must delete orphans and restore CASCADE/NOT NULL."""
    downgrade_source = inspect.getsource(_load_migration().downgrade)

    # Orphaned rows have to be drained BEFORE NOT NULL can be re-imposed.
    for drain_statement in (
        "DELETE FROM document_revisions WHERE document_id IS NULL",
        "DELETE FROM folder_revisions WHERE folder_id IS NULL",
    ):
        assert drain_statement in downgrade_source

    # After that the original CASCADE / NOT NULL contract comes back.
    cascade_count = downgrade_source.count('ondelete="CASCADE"')
    assert cascade_count >= 2
    assert "nullable=False" in downgrade_source
|
||||||
|
|
@ -168,6 +168,8 @@ class TestModeSpecificPrompts:
|
||||||
"edit_file",
|
"edit_file",
|
||||||
"move_file",
|
"move_file",
|
||||||
"mkdir",
|
"mkdir",
|
||||||
|
"rm",
|
||||||
|
"rmdir",
|
||||||
"list_tree",
|
"list_tree",
|
||||||
"grep",
|
"grep",
|
||||||
):
|
):
|
||||||
|
|
@ -182,6 +184,8 @@ class TestModeSpecificPrompts:
|
||||||
"edit_file",
|
"edit_file",
|
||||||
"move_file",
|
"move_file",
|
||||||
"mkdir",
|
"mkdir",
|
||||||
|
"rm",
|
||||||
|
"rmdir",
|
||||||
"list_tree",
|
"list_tree",
|
||||||
"grep",
|
"grep",
|
||||||
):
|
):
|
||||||
|
|
@ -190,6 +194,18 @@ class TestModeSpecificPrompts:
|
||||||
assert "/documents/" not in text, f"{name} mentions cloud namespace"
|
assert "/documents/" not in text, f"{name} mentions cloud namespace"
|
||||||
assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"
|
assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"
|
||||||
|
|
||||||
|
def test_cloud_descs_include_rm_and_rmdir(self):
    """Cloud-mode tool descriptions cover ``rm`` and ``rmdir``."""
    cloud_descs = _build_tool_descriptions(FilesystemMode.CLOUD)
    assert "rm" in cloud_descs and "rmdir" in cloud_descs

    assert "Deletes a single file" in cloud_descs["rm"]
    assert "Deletes an empty directory" in cloud_descs["rmdir"]
    assert "rmdir" in cloud_descs["rmdir"] and "POSIX" in cloud_descs["rmdir"]
|
||||||
|
|
||||||
|
def test_desktop_descs_warn_about_irreversibility(self):
    """Desktop-mode ``rm`` / ``rmdir`` descriptions say deletes are not reversible."""
    desktop_descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER)
    for tool_name in ("rm", "rmdir"):
        assert "NOT reversible" in desktop_descs[tool_name]
|
||||||
|
|
||||||
def test_sandbox_addendum_appended_when_available(self):
|
def test_sandbox_addendum_appended_when_available(self):
|
||||||
prompt = _build_filesystem_system_prompt(
|
prompt = _build_filesystem_system_prompt(
|
||||||
FilesystemMode.CLOUD, sandbox_available=True
|
FilesystemMode.CLOUD, sandbox_available=True
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,309 @@
|
||||||
|
"""Unit tests for the kb_persistence snapshot helpers.
|
||||||
|
|
||||||
|
The full ``commit_staged_filesystem_state`` body exercises a real session
|
||||||
|
in integration tests; here we verify the building blocks used by the
|
||||||
|
snapshot/revert pipeline:
|
||||||
|
|
||||||
|
* ``_find_action_ids_batch`` issues a SINGLE query for N tool_call_ids
|
||||||
|
(regression guard against the N+1 lookup pattern).
|
||||||
|
* ``_mark_action_reversible`` is a no-op when ``action_id`` is ``None``.
|
||||||
|
* ``_doc_revision_payload`` and ``_load_chunks_for_snapshot`` produce the
|
||||||
|
shape the snapshot helpers consume.
|
||||||
|
|
||||||
|
These tests use ``MagicMock`` / ``AsyncMock`` against a fake session so
|
||||||
|
the assertions run in milliseconds and don't require Postgres.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware import kb_persistence
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResult:
|
||||||
|
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
|
||||||
|
self._rows = rows or []
|
||||||
|
self._scalar = scalar
|
||||||
|
|
||||||
|
def all(self) -> list[Any]:
|
||||||
|
return list(self._rows)
|
||||||
|
|
||||||
|
def scalar_one_or_none(self) -> Any:
|
||||||
|
return self._scalar
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.execute = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_find_action_ids_batch_issues_single_query() -> None:
    """N tool_call_ids MUST resolve via one ``IN (...)`` SELECT, not N selects."""
    expected = {"tc-a": 11, "tc-b": 22, "tc-c": 33}

    fake_session = _FakeSession()
    fake_session.execute.return_value = _FakeResult(
        rows=[
            MagicMock(id=action_id, tool_call_id=tool_call_id)
            for tool_call_id, action_id in expected.items()
        ]
    )

    mapping = await kb_persistence._find_action_ids_batch(
        fake_session,  # type: ignore[arg-type]
        thread_id=1,
        tool_call_ids=set(expected),
    )

    assert mapping == {"tc-a": 11, "tc-b": 22, "tc-c": 33}
    assert fake_session.execute.await_count == 1, (
        "Snapshot binding must batch into ONE query; got "
        f"{fake_session.execute.await_count} (regression: N+1 lookup pattern)."
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_thread_id_missing() -> None:
    """With ``thread_id=None`` the helper returns ``{}`` and runs no query."""
    fake_session = _FakeSession()

    result = await kb_persistence._find_action_ids_batch(
        fake_session,  # type: ignore[arg-type]
        thread_id=None,
        tool_call_ids={"tc-a"},
    )

    assert result == {}
    assert fake_session.execute.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_no_calls() -> None:
    """An empty ``tool_call_ids`` set returns ``{}`` and runs no query."""
    fake_session = _FakeSession()

    result = await kb_persistence._find_action_ids_batch(
        fake_session,  # type: ignore[arg-type]
        thread_id=42,
        tool_call_ids=set(),
    )

    assert result == {}
    assert fake_session.execute.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_mark_action_reversible_is_noop_for_null_id() -> None:
    """``action_id=None`` must not reach ``session.execute``."""
    fake_session = _FakeSession()

    await kb_persistence._mark_action_reversible(
        fake_session,  # type: ignore[arg-type]
        action_id=None,
    )

    assert fake_session.execute.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_mark_action_reversible_runs_update_for_real_id() -> None:
    """A concrete ``action_id`` results in exactly one ``session.execute`` await."""
    fake_session = _FakeSession()

    await kb_persistence._mark_action_reversible(
        fake_session,  # type: ignore[arg-type]
        action_id=99,
    )

    assert fake_session.execute.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_revision_payload_captures_metadata_virtual_path() -> None:
    """The payload must carry ``metadata_before`` so revert can reuse it."""
    document = MagicMock(
        content="body",
        title="notes.md",
        folder_id=7,
        document_metadata={"virtual_path": "/documents/team/notes.md"},
    )

    payload = kb_persistence._doc_revision_payload(
        document, chunks_before=[{"content": "x"}]
    )

    expected_fields = {
        "title_before": "notes.md",
        "folder_id_before": 7,
        "content_before": "body",
        "chunks_before": [{"content": "x"}],
        "metadata_before": {"virtual_path": "/documents/team/notes.md"},
    }
    for field_name, expected_value in expected_fields.items():
        assert payload[field_name] == expected_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_revision_payload_handles_missing_metadata() -> None:
    """A document with no metadata yields ``metadata_before`` of ``None``."""
    document = MagicMock(
        content="",
        title="",
        folder_id=None,
        document_metadata=None,
    )

    payload = kb_persistence._doc_revision_payload(document)

    assert payload["metadata_before"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_load_chunks_for_snapshot_returns_content_only() -> None:
    """Snapshot chunks intentionally omit embeddings (regenerated on revert)."""
    fake_session = _FakeSession()
    fake_session.execute.return_value = _FakeResult(
        rows=[MagicMock(content=text) for text in ("alpha", "beta")]
    )

    snapshot_chunks = await kb_persistence._load_chunks_for_snapshot(
        fake_session,
        doc_id=42,  # type: ignore[arg-type]
    )

    assert snapshot_chunks == [{"content": "alpha"}, {"content": "beta"}]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Deferred reversibility-flip dispatches.
|
||||||
|
#
|
||||||
|
# The snapshot helpers used to dispatch ``action_log_updated`` directly
|
||||||
|
# from inside the SAVEPOINT block. That meant the SSE side-channel
|
||||||
|
# could tell the UI a row was reversible while the OUTER transaction
|
||||||
|
# was still pending — and if the outer commit failed, every SAVEPOINT
|
||||||
|
# rolled back too, leaving the UI in a state inconsistent with
|
||||||
|
# durable storage. The deferred-dispatch contract fixes that:
|
||||||
|
#
|
||||||
|
# • when a ``deferred_dispatches`` list is provided, the helper
|
||||||
|
# APPENDS the action_id and does NOT dispatch;
|
||||||
|
# • the caller (``commit_staged_filesystem_state``) flushes the list
|
||||||
|
# only AFTER ``await session.commit()`` succeeds; on rollback it
|
||||||
|
# clears the list so nothing is emitted.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _NestedCtx:
|
||||||
|
"""Async context manager mimicking ``session.begin_nested()``."""
|
||||||
|
|
||||||
|
async def __aenter__(self) -> _NestedCtx:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_pre_write_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Helpers MUST queue dispatches when ``deferred_dispatches`` is set."""

    def _assign_revision_id(revision: Any) -> None:
        # Stamps an id onto whatever object is handed to ``session.add``.
        revision.id = 17

    session = MagicMock(
        begin_nested=MagicMock(return_value=_NestedCtx()),
        execute=AsyncMock(return_value=_FakeResult(rows=[])),
        flush=AsyncMock(),
        add=MagicMock(side_effect=_assign_revision_id),
    )

    inline_dispatches: list[int] = []

    async def _record_dispatch(action_id: int | None) -> None:
        if action_id is None:
            return
        inline_dispatches.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _record_dispatch
    )

    queued: list[int] = []
    document = MagicMock(
        id=99,
        document_metadata={"virtual_path": "/documents/x.md"},
        title="x.md",
        folder_id=None,
        content="body",
    )

    revision_id = await kb_persistence._snapshot_document_pre_write(
        session,  # type: ignore[arg-type]
        doc=document,
        action_id=42,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=queued,
    )

    assert revision_id == 17
    # Inline dispatch must NOT have fired; the action_id is queued instead.
    assert inline_dispatches == []
    assert queued == [42]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_pre_write_snapshot_dispatches_inline_when_list_omitted(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Direct callers (no outer transaction) keep the legacy inline dispatch."""

    def _assign_revision_id(revision: Any) -> None:
        # Stamps an id onto whatever object is handed to ``session.add``.
        revision.id = 7

    session = MagicMock(
        begin_nested=MagicMock(return_value=_NestedCtx()),
        execute=AsyncMock(return_value=_FakeResult(rows=[])),
        flush=AsyncMock(),
        add=MagicMock(side_effect=_assign_revision_id),
    )

    inline_dispatches: list[int] = []

    async def _record_dispatch(action_id: int | None) -> None:
        if action_id is None:
            return
        inline_dispatches.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _record_dispatch
    )

    document = MagicMock(
        id=11,
        document_metadata={"virtual_path": "/documents/y.md"},
        title="y.md",
        folder_id=None,
        content="body",
    )

    # ``deferred_dispatches`` is deliberately not passed, so the helper
    # falls back to dispatching inline.
    await kb_persistence._snapshot_document_pre_write(
        session,  # type: ignore[arg-type]
        doc=document,
        action_id=88,
        search_space_id=1,
        turn_id="t-1",
    )

    assert inline_dispatches == [88]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_pre_mkdir_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Folder mkdir snapshots honour the same deferred-dispatch contract."""
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock()  # _mark_action_reversible calls execute
    session.flush = AsyncMock()

    def _add(rev: Any) -> None:
        # Stamps an id onto whatever object is handed to ``session.add``.
        rev.id = 3

    session.add = MagicMock(side_effect=_add)

    dispatched: list[int] = []

    async def _fake_dispatch(action_id: int | None) -> None:
        if action_id is not None:
            dispatched.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
    )

    deferred: list[int] = []
    # ``MagicMock(name=...)`` names the mock itself and does NOT create a
    # ``.name`` attribute, so the folder name is assigned after construction
    # to make ``folder.name`` the plain string ``"f"``.
    folder = MagicMock(id=2, parent_id=None, position="a0")
    folder.name = "f"

    await kb_persistence._snapshot_folder_pre_mkdir(
        session,  # type: ignore[arg-type]
        folder=folder,
        action_id=55,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=deferred,
    )

    # Nothing is dispatched inline; the action_id is queued for the caller.
    assert dispatched == []
    assert deferred == [55]
|
||||||
139
surfsense_backend/tests/unit/middleware/test_knowledge_tree.py
Normal file
139
surfsense_backend/tests/unit/middleware/test_knowledge_tree.py
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
"""Unit tests for ``KnowledgeTreeMiddleware`` rendering.
|
||||||
|
|
||||||
|
The empty-folder marker is critical UX: without it, the LLM cannot
|
||||||
|
distinguish a leaf folder containing one document from a leaf folder
|
||||||
|
that has no descendants at all, and ends up firing ``rmdir`` on
|
||||||
|
non-empty folders. These tests pin the rendering contract so that
|
||||||
|
contract cannot silently regress.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.knowledge_tree import KnowledgeTreeMiddleware
|
||||||
|
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT
|
||||||
|
|
||||||
|
|
||||||
|
def _compute(folder_paths: list[str], doc_paths: list[str]) -> set[str]:
    """Shorthand for ``KnowledgeTreeMiddleware._compute_non_empty_folders``."""
    non_empty_folders = KnowledgeTreeMiddleware._compute_non_empty_folders(
        folder_paths, doc_paths
    )
    return non_empty_folders
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeNonEmptyFolders:
    """Pins which folders ``_compute_non_empty_folders`` reports as non-empty."""

    def test_folder_with_direct_document_is_non_empty(self):
        folders = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"]
        documents = [
            f"{DOCUMENTS_ROOT}/Travel/Boarding Pass/southwest.pdf.xml",
        ]
        result = _compute(folders, documents)
        assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" in result

    def test_truly_empty_leaf_folder_is_not_non_empty(self):
        folders = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"]
        no_documents: list[str] = []
        assert _compute(folders, no_documents) == set()

    def test_documents_propagate_up_to_all_ancestors(self):
        folders = [
            f"{DOCUMENTS_ROOT}/A",
            f"{DOCUMENTS_ROOT}/A/B",
            f"{DOCUMENTS_ROOT}/A/B/C",
        ]
        result = _compute(folders, [f"{DOCUMENTS_ROOT}/A/B/C/file.xml"])
        # One deeply nested document marks every folder on its path.
        assert result == {
            f"{DOCUMENTS_ROOT}/A",
            f"{DOCUMENTS_ROOT}/A/B",
            f"{DOCUMENTS_ROOT}/A/B/C",
        }

    def test_chain_with_subfolders_marks_only_leaf_empty(self):
        # POSIX-like semantic: a folder counts as "empty" only when it has
        # no immediate children (docs OR sub-folders). The model relies on
        # this because parallel ``rmdir`` calls all observe the same
        # starting state, so removing a parent before its children is
        # never safe.
        folders = [
            f"{DOCUMENTS_ROOT}/X",
            f"{DOCUMENTS_ROOT}/X/Y",
            f"{DOCUMENTS_ROOT}/X/Y/Z",
        ]
        result = _compute(folders, [])
        # ``X/Y/Z`` (the leaf) is the only empty folder. ``X`` and ``X/Y``
        # each hold a sub-folder, so they are non-empty and must NOT get
        # the ``(empty)`` marker.
        assert result == {f"{DOCUMENTS_ROOT}/X", f"{DOCUMENTS_ROOT}/X/Y"}

    def test_sibling_with_doc_does_not_mark_other_sibling_non_empty(self):
        # Mirrors a real DB layout in which every intermediate folder is
        # materialized in the ``folders`` table.
        folders = [
            f"{DOCUMENTS_ROOT}/Travel",
            f"{DOCUMENTS_ROOT}/Travel/Boarding Pass",
            f"{DOCUMENTS_ROOT}/Travel/Notes",
        ]
        documents = [f"{DOCUMENTS_ROOT}/Travel/Notes/itinerary.xml"]
        result = _compute(folders, documents)
        # ``Travel`` has children and ``Notes`` has the doc, so both are
        # non-empty; the sibling leaf ``Boarding Pass`` stays empty.
        assert f"{DOCUMENTS_ROOT}/Travel" in result
        assert f"{DOCUMENTS_ROOT}/Travel/Notes" in result
        assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" not in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatTreeRendering:
    """Integration check: empty leaf gets ``(empty)`` marker; non-empty doesn't."""

    def _render(
        self,
        folder_paths: list[str],
        doc_specs: list[dict],
    ) -> str:
        """Render the tree for the given folders and document specs.

        Folder ids are assigned 1..N following the order of ``folder_paths``.
        """
        from app.agents.new_chat.path_resolver import PathIndex

        path_index = PathIndex(
            folder_paths=dict(enumerate(folder_paths, start=1)),
        )

        class _Row:
            # Attribute bag: every keyword becomes an instance attribute.
            def __init__(self, **kw):
                vars(self).update(kw)

        document_rows = [_Row(**spec) for spec in doc_specs]

        middleware = KnowledgeTreeMiddleware(
            search_space_id=1,
            filesystem_mode=None,  # type: ignore[arg-type]
        )
        return middleware._format_tree(path_index, document_rows)

    def test_renders_empty_marker_only_for_truly_empty_folders(self):
        # Reproduces the failure scenario from the bug report:
        # ``Boarding Pass`` is empty (its only doc was just deleted) while
        # ``Tax Returns`` still holds ``federal.pdf``. Every intermediate
        # folder is present in the index, mirroring the real DB layout.
        folder_paths = [
            "/documents/File Upload",
            "/documents/File Upload/2026-04-08",
            "/documents/File Upload/2026-04-08/Travel",
            "/documents/File Upload/2026-04-08/Travel/Boarding Pass",
            "/documents/File Upload/2026-04-15",
            "/documents/File Upload/2026-04-15/Finance",
            "/documents/File Upload/2026-04-15/Finance/Tax Returns",
        ]
        # ``_render`` numbers folders from 1, hence the ``+ 1``.
        tax_returns_position = folder_paths.index(
            "/documents/File Upload/2026-04-15/Finance/Tax Returns"
        )
        tax_returns_folder_id = tax_returns_position + 1

        federal_doc = {
            "id": 100,
            "title": "federal.pdf",
            "folder_id": tax_returns_folder_id,
        }
        rendered = self._render(
            folder_paths=folder_paths,
            doc_specs=[federal_doc],
        )

        assert "Boarding Pass/ (empty)" in rendered
        # Neither the doc's own folder nor any intermediate ancestor of
        # the doc may be marked empty.
        for unexpected_marker in (
            "Tax Returns/ (empty)",
            "Finance/ (empty)",
            "2026-04-15/ (empty)",
        ):
            assert unexpected_marker not in rendered
|
||||||
|
|
@ -69,3 +69,74 @@ def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path):
|
||||||
assert write.error is not None
|
assert write.error is not None
|
||||||
assert "parent directory" in write.error
|
assert "parent directory" in write.error
|
||||||
assert not (tmp_path / "tempoo").exists()
|
assert not (tmp_path / "tempoo").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_delete_file_success(tmp_path: Path):
    """Deleting an existing file reports its path and removes it from disk."""
    local_backend = LocalFolderBackend(str(tmp_path))
    doomed_file = tmp_path / "delete-me.md"
    doomed_file.write_text("bye")

    outcome = local_backend.delete_file("/delete-me.md")

    assert outcome.error is None
    assert outcome.path == "/delete-me.md"
    assert not doomed_file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_delete_file_rejects_directory(tmp_path: Path):
    """``delete_file`` on a directory errors out and leaves it in place."""
    local_backend = LocalFolderBackend(str(tmp_path))
    directory = tmp_path / "subdir"
    directory.mkdir()

    outcome = local_backend.delete_file("/subdir")

    assert outcome.error is not None
    assert "directory" in outcome.error
    assert directory.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_delete_file_missing_returns_error(tmp_path: Path):
    """Deleting a path that does not exist yields a "not found" error."""
    outcome = LocalFolderBackend(str(tmp_path)).delete_file("/nope.md")

    assert outcome.error is not None
    assert "not found" in outcome.error
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_rmdir_success(tmp_path: Path):
    """``rmdir`` on an empty directory reports its path and removes it."""
    local_backend = LocalFolderBackend(str(tmp_path))
    empty_dir = tmp_path / "empty"
    empty_dir.mkdir()

    outcome = local_backend.rmdir("/empty")

    assert outcome.error is None
    assert outcome.path == "/empty"
    assert not empty_dir.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_rmdir_rejects_non_empty(tmp_path: Path):
    """``rmdir`` refuses a directory with a child and leaves the child intact."""
    local_backend = LocalFolderBackend(str(tmp_path))
    parent_dir = tmp_path / "withkid"
    parent_dir.mkdir()
    child_file = parent_dir / "child.md"
    child_file.write_text("x")

    outcome = local_backend.rmdir("/withkid")

    assert outcome.error is not None
    assert "not empty" in outcome.error
    assert child_file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_rmdir_rejects_file(tmp_path: Path):
    """``rmdir`` on a regular file yields a "not a directory" error."""
    local_backend = LocalFolderBackend(str(tmp_path))
    regular_file = tmp_path / "f.md"
    regular_file.write_text("x")

    outcome = local_backend.rmdir("/f.md")

    assert outcome.error is not None
    assert "not a directory" in outcome.error
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_rmdir_rejects_root(tmp_path: Path):
    """``rmdir /`` MUST fail.

    The exact error wording comes from ``_resolve_virtual`` (root resolves
    to outside the sandbox). What matters is that the call returns an
    error and does NOT delete the sandbox root on disk.
    """
    outcome = LocalFolderBackend(str(tmp_path)).rmdir("/")

    assert outcome.error is not None
    assert "Invalid path" in outcome.error or "root" in outcome.error
    assert tmp_path.exists()
|
||||||
|
|
|
||||||
0
surfsense_backend/tests/unit/routes/__init__.py
Normal file
0
surfsense_backend/tests/unit/routes/__init__.py
Normal file
|
|
@ -0,0 +1,143 @@
|
||||||
|
"""Unit tests for the edit-from-arbitrary-position helpers inside ``new_chat_routes``.
|
||||||
|
|
||||||
|
The regenerate route's edit-from-position path introduces:
|
||||||
|
* ``_find_pre_turn_checkpoint_id`` — walks LangGraph checkpoint tuples
|
||||||
|
newest-first and picks the first one whose ``metadata["turn_id"]``
|
||||||
|
differs from the edited turn. That checkpoint is the rewind target
|
||||||
|
(state immediately before the edited turn started).
|
||||||
|
* ``RegenerateRequest`` accepts ``from_message_id`` + ``revert_actions``
|
||||||
|
with a validator that prevents callers from requesting a revert pass
|
||||||
|
without specifying which turn to roll back.
|
||||||
|
|
||||||
|
These are pure-Python helpers that don't need a live DB, so we exercise
|
||||||
|
them with a small ``CheckpointTuple``-shaped namespace and direct
|
||||||
|
schema instantiation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.routes.new_chat_routes import _find_pre_turn_checkpoint_id
|
||||||
|
from app.schemas.new_chat import RegenerateRequest
|
||||||
|
|
||||||
|
|
||||||
|
def _cp(checkpoint_id: str, turn_id: str | None) -> SimpleNamespace:
|
||||||
|
"""Build a fake ``CheckpointTuple`` with the metadata shape we read."""
|
||||||
|
return SimpleNamespace(
|
||||||
|
config={"configurable": {"checkpoint_id": checkpoint_id}},
|
||||||
|
metadata={"turn_id": turn_id} if turn_id is not None else {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindPreTurnCheckpointId:
    """Behaviour of ``_find_pre_turn_checkpoint_id`` over newest-first histories."""

    def test_returns_last_pre_turn_checkpoint_when_editing_latest_turn(self) -> None:
        # History is newest-first and T2 is the most recent turn. The
        # newest checkpoint that does not belong to T2 (cp2) is the state
        # immediately before T2 began, i.e. the rewind target.
        history = [
            _cp(cp_id, turn)
            for cp_id, turn in (
                ("cp4", "T2"),
                ("cp3", "T2"),
                ("cp2", "T1"),
                ("cp1", "T1"),
            )
        ]
        rewind_target = _find_pre_turn_checkpoint_id(history, turn_id="T2")
        assert rewind_target == "cp2"

    def test_returns_pre_turn_checkpoint_when_later_turns_exist(self) -> None:
        # Regression guard: a naive newest-first walk that returns the
        # FIRST checkpoint with ``turn_id != target`` lands on a later-turn
        # checkpoint instead of the pre-turn boundary. Editing T2 must
        # rewind to the latest T1 checkpoint (cp2), never to the latest T3
        # checkpoint (cp6).
        history = [
            _cp(cp_id, turn)
            for cp_id, turn in (
                ("cp6", "T3"),
                ("cp5", "T3"),
                ("cp4", "T2"),
                ("cp3", "T2"),
                ("cp2", "T1"),
                ("cp1", "T1"),
            )
        ]
        rewind_target = _find_pre_turn_checkpoint_id(history, turn_id="T2")
        assert rewind_target == "cp2"

    def test_returns_none_when_editing_first_turn(self) -> None:
        # Nothing precedes T1, so there is no pre-turn boundary; the caller
        # is expected to fall back to the oldest checkpoint or special-case
        # "first turn of the thread".
        history = [
            _cp(cp_id, turn)
            for cp_id, turn in (
                ("cp4", "T2"),
                ("cp3", "T2"),
                ("cp2", "T1"),
                ("cp1", "T1"),
            )
        ]
        rewind_target = _find_pre_turn_checkpoint_id(history, turn_id="T1")
        assert rewind_target is None

    def test_returns_none_when_only_edited_turn_present(self) -> None:
        history = [_cp(cp_id, "T2") for cp_id in ("cp2", "cp1")]
        rewind_target = _find_pre_turn_checkpoint_id(history, turn_id="T2")
        assert rewind_target is None

    def test_returns_none_for_empty_history(self) -> None:
        rewind_target = _find_pre_turn_checkpoint_id([], turn_id="T1")
        assert rewind_target is None

    def test_legacy_checkpoints_without_turn_id_count_as_pre_turn(self) -> None:
        # Checkpoints written before migration 136 carry no
        # ``metadata.turn_id``. They remain eligible rewind targets because
        # they came before the edited turn began.
        legacy = SimpleNamespace(
            config={"configurable": {"checkpoint_id": "cp2"}},
            metadata=None,
        )
        history = [_cp("cp3", "T2"), legacy, _cp("cp1", "T1")]
        # Oldest-first walk: cp1(T1) tracked, cp2(legacy/None) tracked,
        # then cp3(T2) crosses the boundary -> cp2 is returned.
        rewind_target = _find_pre_turn_checkpoint_id(history, turn_id="T2")
        assert rewind_target == "cp2"

    def test_skips_checkpoint_missing_checkpoint_id_in_config(self) -> None:
        # A tuple whose ``config["configurable"]`` lacks ``checkpoint_id``
        # (corrupt / partial) must not crash the search; the last known
        # good target is kept instead.
        broken = SimpleNamespace(
            config={"configurable": {}}, metadata={"turn_id": "T1"}
        )
        history = [_cp("cp3", "T2"), broken, _cp("cp1", "T1")]
        # cp1(T1) tracked, broken skipped, cp3(T2) -> cp1 is returned.
        rewind_target = _find_pre_turn_checkpoint_id(history, turn_id="T2")
        assert rewind_target == "cp1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegenerateRequestValidation:
    """Validation rules linking ``revert_actions`` to ``from_message_id``."""

    def test_revert_actions_requires_from_message_id(self) -> None:
        """Requesting a revert pass without naming the turn must be rejected."""
        # ``ValueError`` rather than a blanket ``Exception``: pydantic's
        # ``ValidationError`` derives from ``ValueError``, and the narrower
        # type keeps an unrelated failure (TypeError, import problem, ...)
        # from satisfying the assertion by accident.
        with pytest.raises(ValueError) as exc:
            RegenerateRequest(
                search_space_id=1,
                user_query="hi",
                revert_actions=True,
            )
        msg = str(exc.value).lower()
        assert "from_message_id" in msg

    def test_from_message_id_without_revert_is_allowed(self) -> None:
        """``from_message_id`` alone is valid; ``revert_actions`` stays False."""
        req = RegenerateRequest(
            search_space_id=1,
            user_query="hi",
            from_message_id=42,
        )
        assert req.from_message_id == 42
        assert req.revert_actions is False

    def test_revert_actions_with_from_message_id_passes(self) -> None:
        """Supplying both fields together passes validation."""
        req = RegenerateRequest(
            search_space_id=1,
            user_query="hi",
            from_message_id=42,
            revert_actions=True,
        )
        assert req.revert_actions is True
|
||||||
530
surfsense_backend/tests/unit/routes/test_revert_turn_route.py
Normal file
530
surfsense_backend/tests/unit/routes/test_revert_turn_route.py
Normal file
|
|
@ -0,0 +1,530 @@
|
||||||
|
"""Unit tests for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
|
||||||
|
|
||||||
|
The per-turn batch revert route walks rows in reverse ``created_at``
|
||||||
|
order, reverts each independently, and returns a per-action result
|
||||||
|
list. Partial success is normal — the response status
|
||||||
|
is ``"partial"`` whenever any row could not be reverted, but we never
|
||||||
|
collapse the whole batch into a 4xx.
|
||||||
|
|
||||||
|
These tests stub ``load_thread`` / ``revert_action`` and feed a fake
|
||||||
|
session, so they exercise the route's dispatch logic without a real DB.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.routes import agent_revert_route
|
||||||
|
from app.services.revert_service import RevertOutcome
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _FakeAction:
    """In-memory stand-in for an action row handed to the revert-turn route."""

    # Identifier the tests later look up through ``result.action_id``.
    id: int
    # Tool that produced the action, e.g. "rm", "edit_file", "_revert:edit_file".
    tool_name: str
    # Owner of the action; the default matches ``_FakeUser``'s default id.
    user_id: str | None = "u1"
    # Populated by the tests only on rows whose tool_name starts with
    # "_revert:" — presumably the id of the action that row reverted.
    reverse_of: int | None = None
    # Left at ``None`` by every test visible in this module.
    error: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _FakeUser:
    """Minimal user object passed as ``user=`` to the route under test."""

    # Defaults to "u1", the same value ``_FakeAction.user_id`` defaults to.
    id: str = "u1"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _ScalarResult:
|
||||||
|
rows: list[Any]
|
||||||
|
|
||||||
|
def first(self) -> Any:
|
||||||
|
return self.rows[0] if self.rows else None
|
||||||
|
|
||||||
|
def all(self) -> list[Any]:
|
||||||
|
return list(self.rows)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Result:
|
||||||
|
rows: list[Any] = field(default_factory=list)
|
||||||
|
|
||||||
|
def scalars(self) -> _ScalarResult:
|
||||||
|
return _ScalarResult(self.rows)
|
||||||
|
|
||||||
|
def all(self) -> list[Any]:
|
||||||
|
# ``_was_already_reverted_batch`` calls ``.all()`` directly on
|
||||||
|
# the row-tuple result (no ``.scalars()`` indirection). The
|
||||||
|
# rows queued for that helper are list[(revert_id, original_id)].
|
||||||
|
return list(self.rows)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeNestedCtx:
|
||||||
|
"""Async context manager that mimics ``session.begin_nested()``.
|
||||||
|
|
||||||
|
The route raises a sentinel exception inside this block to roll back
|
||||||
|
bad rows. We just pass the exception through.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def __aenter__(self) -> _FakeNestedCtx:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||||
|
# Returning False (or None) propagates the exception; the route
|
||||||
|
# catches its own sentinel above this layer.
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
"""Minimal AsyncSession stand-in for the revert-turn route.
|
||||||
|
|
||||||
|
Holds a queue of result objects; each ``execute(...)`` pops the next
|
||||||
|
one. The route calls ``execute`` exactly once per query so this maps
|
||||||
|
cleanly onto the assertion order of the test.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._results: list[_Result] = []
|
||||||
|
self.committed = False
|
||||||
|
self.rolled_back = False
|
||||||
|
# Count execute() calls to assert "no N+1 reverts".
|
||||||
|
self.execute_call_count = 0
|
||||||
|
|
||||||
|
def queue(self, *results: _Result) -> None:
|
||||||
|
self._results.extend(results)
|
||||||
|
|
||||||
|
async def execute(self, _stmt: Any) -> _Result:
|
||||||
|
self.execute_call_count += 1
|
||||||
|
if not self._results:
|
||||||
|
return _Result(rows=[])
|
||||||
|
return self._results.pop(0)
|
||||||
|
|
||||||
|
def begin_nested(self) -> _FakeNestedCtx:
|
||||||
|
return _FakeNestedCtx()
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
self.committed = True
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
self.rolled_back = True
|
||||||
|
|
||||||
|
|
||||||
|
def _enabled_flags() -> AgentFeatureFlags:
    """Return feature flags with the action log and revert route switched on."""
    overrides = {
        "disable_new_agent_stack": False,
        "enable_action_log": True,
        "enable_revert_route": True,
    }
    return AgentFeatureFlags(**overrides)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def patch_get_flags():
    """Provide a factory that patches the route module's ``get_flags``.

    The returned callable takes an ``AgentFeatureFlags`` instance and gives
    back the (not yet entered) ``patch`` object, so tests use it inside a
    ``with`` statement.
    """
    target = "app.routes.agent_revert_route.get_flags"

    def _patch(flags: AgentFeatureFlags):
        return patch(target, return_value=flags)

    return _patch
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlagGuard:
    """The route must refuse to run while ``enable_revert_route`` is off."""

    @pytest.mark.asyncio
    async def test_returns_503_when_revert_route_disabled(
        self, patch_get_flags
    ) -> None:
        disabled = AgentFeatureFlags(
            disable_new_agent_stack=False,
            enable_action_log=True,
            enable_revert_route=False,
        )
        fake_session = _FakeSession()
        with patch_get_flags(disabled), pytest.raises(Exception) as raised:
            await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="42:1700000000000",
                session=fake_session,
                user=_FakeUser(),
            )
        # The raised error is only inspected for its HTTP status code.
        status_code = getattr(raised.value, "status_code", None)
        assert status_code == 503
|
||||||
|
|
||||||
|
|
||||||
|
class TestRevertTurnDispatch:
    """Dispatch behaviour of ``revert_agent_turn`` with the feature flag on.

    Every test queues fake ``execute`` results in the order the route is
    expected to query them (action rows first, then the batched
    idempotency probe), stubs ``load_thread`` and — where relevant —
    ``revert_action``, then inspects the per-row results and counters on
    the response.
    """

    @pytest.mark.asyncio
    async def test_empty_turn_returns_ok_with_no_rows(self, patch_get_flags) -> None:
        """A turn with no action rows yields ``ok``, zero totals, and a commit."""
        session = _FakeSession()
        session.queue(_Result(rows=[]))  # rows query returns nothing
        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-empty",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.total == 0
        assert response.results == []
        assert session.committed is True

    @pytest.mark.asyncio
    async def test_walks_rows_in_reverse_and_reverts_each(
        self, patch_get_flags
    ) -> None:
        """Each queued row is reverted, in queue order, with two queries total."""
        rows = [
            _FakeAction(id=10, tool_name="rm"),
            _FakeAction(id=9, tool_name="write_file"),
            _FakeAction(id=8, tool_name="mkdir"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched ``_was_already_reverted_batch`` probe replaces
        # the previous N per-row SELECTs.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            # new_action_id is derived from the action id so the test can
            # tell which row produced which outcome.
            return RevertOutcome(
                status="ok",
                message=f"reverted-{action.id}",
                new_action_id=100 + action.id,
            )

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-3",
                session=session,
                user=_FakeUser(),
            )

        assert response.status == "ok"
        assert response.total == 3
        assert response.reverted == 3
        assert [r.action_id for r in response.results] == [10, 9, 8]
        assert all(r.status == "reverted" for r in response.results)
        assert response.results[0].new_action_id == 110
        # Only TWO ``execute`` calls regardless of the row count: one
        # for the rows query, one for the batched
        # ``_was_already_reverted_batch`` probe. Regression guard
        # against re-introducing the per-row N+1 lookup.
        assert session.execute_call_count == 2, (
            "revert-turn loop must batch idempotency probes; got "
            f"{session.execute_call_count} execute() calls (expected 2)."
        )

    @pytest.mark.asyncio
    async def test_already_reverted_rows_are_marked_idempotent(
        self, patch_get_flags
    ) -> None:
        """A row found by the batch probe is reported without a new revert."""
        rows = [_FakeAction(id=5, tool_name="edit_file")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch probe returns ``[(revert_id, original_id)]``.
        session.queue(_Result(rows=[(42, 5)]))

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-i",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.already_reverted == 1
        assert response.results[0].status == "already_reverted"
        # The existing revert's id (42) is surfaced as ``new_action_id``.
        assert response.results[0].new_action_id == 42
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_revert_action_skips_existing_revert_rows(
        self, patch_get_flags
    ) -> None:
        """Rows that are themselves reverts are reported as ``skipped``."""
        rows = [_FakeAction(id=99, tool_name="_revert:edit_file", reverse_of=42)]
        session = _FakeSession()
        # Only the rows query is queued; any further ``execute`` call would
        # receive ``_FakeSession``'s default empty result.
        session.queue(_Result(rows=rows))

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-rev",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "ok"
        assert response.results[0].status == "skipped"
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_partial_success_when_some_rows_not_reversible(
        self, patch_get_flags
    ) -> None:
        """One ``not_reversible`` row downgrades the batch status to ``partial``."""
        rows = [
            _FakeAction(id=2, tool_name="send_email"),
            _FakeAction(id=1, tool_name="edit_file"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched idempotency probe.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.tool_name == "send_email":
                return RevertOutcome(
                    status="not_reversible",
                    message="connector revert not yet implemented",
                )
            return RevertOutcome(status="ok", message="ok", new_action_id=500)

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-mix",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "partial"
        assert response.reverted == 1
        assert response.not_reversible == 1
        # Sorted so the assertion does not depend on result ordering.
        statuses = sorted(r.status for r in response.results)
        assert statuses == ["not_reversible", "reverted"]

    @pytest.mark.asyncio
    async def test_unexpected_exception_marks_row_failed_not_batch(
        self, patch_get_flags
    ) -> None:
        """An exception from one row fails that row only; the other still reverts."""
        rows = [
            _FakeAction(id=20, tool_name="edit_file"),
            _FakeAction(id=21, tool_name="edit_file"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched idempotency probe.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.id == 20:
                raise RuntimeError("disk on fire")
            return RevertOutcome(status="ok", message="ok", new_action_id=999)

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-fail",
                session=session,
                user=_FakeUser(),
            )
        assert response.status == "partial"
        assert response.failed == 1
        assert response.reverted == 1
        bad = next(r for r in response.results if r.action_id == 20)
        assert bad.status == "failed"
        # The original exception text must be carried in the row's error.
        assert "disk on fire" in (bad.error or "")
        good = next(r for r in response.results if r.action_id == 21)
        assert good.status == "reverted"

    @pytest.mark.asyncio
    async def test_permission_denied_when_other_user_owns_action(
        self, patch_get_flags
    ) -> None:
        """A row owned by a different user is denied and never reverted."""
        rows = [_FakeAction(id=7, tool_name="edit_file", user_id="someone-else")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch idempotency probe (no prior reverts).
        session.queue(_Result(rows=[]))

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-perm",
                session=session,
                user=_FakeUser(id="not-owner"),
            )
        assert response.status == "partial"
        assert response.results[0].status == "permission_denied"
        # ``permission_denied`` has its own dedicated counter so the
        # response invariant ``total == sum(counters)`` always holds
        # without overloading ``not_reversible`` (which historically
        # absorbed this case and confused frontend toasts).
        assert response.permission_denied == 1
        assert response.not_reversible == 0
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_counter_invariant_holds_across_mixed_outcomes(
        self, patch_get_flags
    ) -> None:
        """Every row is accounted for in EXACTLY ONE counter.

        Mixes one of every supported outcome (reverted, already_reverted,
        not_reversible, permission_denied, failed, skipped) and asserts
        that the sum of counters equals ``response.total``.
        """
        rows = [
            _FakeAction(id=10, tool_name="edit_file"),  # ok
            _FakeAction(id=9, tool_name="edit_file"),  # already_reverted
            _FakeAction(id=8, tool_name="send_email"),  # not_reversible
            _FakeAction(id=7, tool_name="rm", user_id="other"),  # permission_denied
            _FakeAction(id=6, tool_name="edit_file"),  # failed
            _FakeAction(id=5, tool_name="_revert:edit_file", reverse_of=99),  # skipped
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched probe; only id=9 has a prior revert.
        # Schema: list[(revert_id, original_id)].
        session.queue(_Result(rows=[(42, 9)]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.id == 10:
                return RevertOutcome(status="ok", message="ok", new_action_id=500)
            if action.id == 8:
                return RevertOutcome(
                    status="not_reversible",
                    message="connector revert not yet implemented",
                )
            if action.id == 6:
                raise RuntimeError("boom")
            # Rows 9, 7 and 5 must be resolved by the route before it
            # ever reaches ``revert_action``.
            raise AssertionError(f"unexpected revert call for {action.id}")

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route,
                "revert_action",
                AsyncMock(side_effect=_fake_revert),
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-mixed-all",
                session=session,
                user=_FakeUser(),  # only id=7 has a different user_id
            )

        assert response.total == len(rows) == 6
        bucket_sum = (
            response.reverted
            + response.already_reverted
            + response.not_reversible
            + response.permission_denied
            + response.failed
            + response.skipped
        )
        assert bucket_sum == response.total, (
            "Counter invariant broken: total "
            f"({response.total}) != sum of counters ({bucket_sum}). "
            f"Counters: reverted={response.reverted}, "
            f"already_reverted={response.already_reverted}, "
            f"not_reversible={response.not_reversible}, "
            f"permission_denied={response.permission_denied}, "
            f"failed={response.failed}, skipped={response.skipped}"
        )
        assert response.reverted == 1
        assert response.already_reverted == 1
        assert response.not_reversible == 1
        assert response.permission_denied == 1
        assert response.failed == 1
        assert response.skipped == 1

    @pytest.mark.asyncio
    async def test_integrity_error_translates_to_already_reverted(
        self, patch_get_flags
    ) -> None:
        """The partial unique index on ``reverse_of`` raises
        ``IntegrityError`` when a concurrent revert wins the race against
        the pre-flight ``_was_already_reverted`` SELECT. The route MUST
        recover by re-querying for the winning revert id and returning
        ``status="already_reverted"`` (not ``"failed"``) so racing
        clients see consistent idempotent semantics.
        """
        # Imported locally: only this test needs SQLAlchemy's exception type.
        from sqlalchemy.exc import IntegrityError

        rows = [_FakeAction(id=33, tool_name="edit_file")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch pre-flight probe: nothing yet (we'll race).
        session.queue(_Result(rows=[]))
        # Post-IntegrityError fallback uses the SCALAR
        # ``_was_already_reverted`` (single-id lookup) so it pulls
        # ``[777]`` via ``.scalars().first()``.
        session.queue(_Result(rows=[777]))

        async def _racing_revert(_session, *, action, requester_user_id):
            # Simulates losing the race: the insert of the revert row
            # violates the unique index.
            raise IntegrityError("INSERT", {}, Exception("dup reverse_of"))

        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(
                agent_revert_route,
                "revert_action",
                AsyncMock(side_effect=_racing_revert),
            ),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="ct-race",
                session=session,
                user=_FakeUser(),
            )

        assert response.failed == 0, (
            "IntegrityError must NOT surface as a failed row; the unique "
            "index is the durable expression of idempotency."
        )
        assert response.already_reverted == 1
        assert response.results[0].status == "already_reverted"
        assert response.results[0].new_action_id == 777
|
||||||
|
|
@ -0,0 +1,370 @@
|
||||||
|
"""Unit tests for the filesystem-tool branches of ``revert_service``.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
|
||||||
|
* Exact-name dispatch — ``rmdir`` does NOT mis-route to the document
|
||||||
|
branch (``"rmdir".startswith("rm")`` would mis-route under the legacy
|
||||||
|
prefix-based dispatch).
|
||||||
|
* ``rm`` revert re-INSERTs a fresh document from the snapshot, including
|
||||||
|
re-creating chunks. Falls back to ``(folder_id_before, title_before)``
|
||||||
|
when ``metadata_before["virtual_path"]`` is missing.
|
||||||
|
* ``write_file`` create-revert (``content_before IS NULL``) DELETEs the
|
||||||
|
document.
|
||||||
|
* ``rmdir`` revert re-INSERTs a fresh folder from the snapshot.
|
||||||
|
* ``mkdir`` revert DELETEs the empty folder; reports ``tool_unavailable``
|
||||||
|
when the folder gained children.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services import revert_service
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _stub_embeddings(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace ``revert_service.embed_texts`` with a zero-vector stub.

    Applied automatically to every test in the module; each input text is
    mapped to an 8-element ``float32`` vector of zeros.
    """

    def _zero_vectors(texts):
        return [np.zeros(8, dtype=np.float32) for _ in texts]

    monkeypatch.setattr(revert_service, "embed_texts", _zero_vectors)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResult:
|
||||||
|
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
|
||||||
|
self._rows = rows or []
|
||||||
|
self._scalar = scalar
|
||||||
|
|
||||||
|
def all(self) -> list[Any]:
|
||||||
|
return list(self._rows)
|
||||||
|
|
||||||
|
def scalar_one_or_none(self) -> Any:
|
||||||
|
return self._scalar
|
||||||
|
|
||||||
|
def scalars(self) -> Any:
|
||||||
|
return _FakeScalarsProxy(self._rows)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeScalarsProxy:
|
||||||
|
def __init__(self, rows: list[Any]) -> None:
|
||||||
|
self._rows = rows
|
||||||
|
|
||||||
|
def first(self) -> Any:
|
||||||
|
return self._rows[0] if self._rows else None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.execute = AsyncMock()
|
||||||
|
self.added: list[Any] = []
|
||||||
|
self.deleted: list[Any] = []
|
||||||
|
self.flush = AsyncMock()
|
||||||
|
# session.get(Model, pk) lookup
|
||||||
|
self.get = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
async def _flush_assigning_ids() -> None:
|
||||||
|
for obj in self.added:
|
||||||
|
if getattr(obj, "id", None) is None:
|
||||||
|
obj.id = 999
|
||||||
|
|
||||||
|
self.flush.side_effect = _flush_assigning_ids
|
||||||
|
|
||||||
|
def add(self, obj: Any) -> None:
|
||||||
|
self.added.append(obj)
|
||||||
|
|
||||||
|
def add_all(self, objs: list[Any]) -> None:
|
||||||
|
self.added.extend(objs)
|
||||||
|
|
||||||
|
|
||||||
|
def _action(*, tool_name: str, action_id: int = 7):
|
||||||
|
return MagicMock(
|
||||||
|
id=action_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=2,
|
||||||
|
user_id="user-1",
|
||||||
|
reverse_descriptor=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _doc_revision(
|
||||||
|
*,
|
||||||
|
document_id: int | None = None,
|
||||||
|
content_before: str | None = "old content",
|
||||||
|
title_before: str | None = "notes.md",
|
||||||
|
folder_id_before: int | None = 5,
|
||||||
|
chunks_before: list[dict[str, str]] | None = None,
|
||||||
|
metadata_before: dict[str, str] | None = None,
|
||||||
|
):
|
||||||
|
revision = MagicMock()
|
||||||
|
revision.id = 100
|
||||||
|
revision.document_id = document_id
|
||||||
|
revision.search_space_id = 2
|
||||||
|
revision.content_before = content_before
|
||||||
|
revision.title_before = title_before
|
||||||
|
revision.folder_id_before = folder_id_before
|
||||||
|
revision.chunks_before = chunks_before or []
|
||||||
|
revision.metadata_before = metadata_before
|
||||||
|
return revision
|
||||||
|
|
||||||
|
|
||||||
|
def _folder_revision(
|
||||||
|
*,
|
||||||
|
folder_id: int | None = None,
|
||||||
|
name_before: str | None = "team",
|
||||||
|
parent_id_before: int | None = None,
|
||||||
|
position_before: str | None = "a0",
|
||||||
|
):
|
||||||
|
revision = MagicMock()
|
||||||
|
revision.id = 200
|
||||||
|
revision.folder_id = folder_id
|
||||||
|
revision.search_space_id = 2
|
||||||
|
revision.name_before = name_before
|
||||||
|
revision.parent_id_before = parent_id_before
|
||||||
|
revision.position_before = position_before
|
||||||
|
return revision
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Exact-name dispatch regression guards
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExactDispatch:
    """Regression: ``rmdir`` MUST NOT route to the document branch."""

    @pytest.mark.asyncio
    async def test_rmdir_does_not_misroute_to_document(self) -> None:
        # A `startswith("rm")` dispatch would send `rmdir` down the
        # document branch. Exact-name lookup places it in `_FOLDER_TOOLS`.
        fake_session = _FakeSession()
        rmdir_action = _action(tool_name="rmdir")
        # No folder revisions exist for this action.
        fake_session.execute.return_value = _FakeResult(rows=[])

        result = await revert_service.revert_action(
            fake_session,  # type: ignore[arg-type]
            action=rmdir_action,
            requester_user_id="user-1",
        )

        assert result.status == "not_reversible"
        assert "folder_revisions" in result.message

    def test_dispatch_sets_split_doc_and_folder(self) -> None:
        # Static guards on the dispatch tables themselves, so a later
        # refactor cannot quietly bring the prefix bug back.
        doc_tools = revert_service._DOC_TOOLS
        folder_tools = revert_service._FOLDER_TOOLS

        assert "rm" in doc_tools
        assert "rmdir" in folder_tools
        assert "rmdir" not in doc_tools
        assert "rm" not in folder_tools
        # ``move_file`` lives only in document tools (it's a doc rename).
        assert "move_file" in doc_tools
        assert "move_file" not in folder_tools
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# rm revert (re-INSERT)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRmRevert:
    """Reverting ``rm`` re-inserts the document from its revision snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_document_with_chunks(self) -> None:
        """A clean re-insert adds one ``Document`` and one ``Chunk`` per snapshot chunk."""
        fake_session = _FakeSession()
        snapshot = _doc_revision(
            document_id=None,  # row was hard-deleted
            content_before="hello world",
            title_before="x.md",
            folder_id_before=None,
            chunks_before=[{"content": "alpha"}, {"content": "beta"}],
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        # Every query comes back empty, so the collision check finds nothing.
        fake_session.execute.return_value = _FakeResult(scalar=None)

        result = await revert_service._reinsert_document_from_revision(
            fake_session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert result.status == "ok"
        # Expect exactly one new Document plus its two chunks.
        from app.db import Chunk, Document

        new_docs = [item for item in fake_session.added if isinstance(item, Document)]
        new_chunks = [item for item in fake_session.added if isinstance(item, Chunk)]
        assert len(new_docs) == 1
        assert new_docs[0].title == "x.md"
        assert len(new_chunks) == 2
        # The snapshot is repointed at the new doc id so a follow-up revert works.
        assert snapshot.document_id == new_docs[0].id

    @pytest.mark.asyncio
    async def test_falls_back_to_folder_id_and_title_for_virtual_path(
        self,
    ) -> None:
        """Without ``metadata_before`` the folder/title fallback still succeeds."""
        fake_session = _FakeSession()
        # No metadata_before on the snapshot, which forces the fallback path.
        snapshot = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="cap.md",
            folder_id_before=42,
            chunks_before=[],
            metadata_before=None,
        )
        # session.get(Folder, 42) hands back a named, parentless folder.
        parent_folder = MagicMock()
        parent_folder.name = "team"
        parent_folder.parent_id = None
        # The first .get call is the folder lookup made while deriving the path.
        fake_session.get = AsyncMock(return_value=parent_folder)
        fake_session.execute.return_value = _FakeResult(scalar=None)

        result = await revert_service._reinsert_document_from_revision(
            fake_session,  # type: ignore[arg-type]
            revision=snapshot,
        )
        assert result.status == "ok"

    @pytest.mark.asyncio
    async def test_falls_back_to_root_path_when_no_folder(
        self,
    ) -> None:
        """metadata_before is None and folder_id_before is None still
        resolves: title fallback yields ``/documents/<title>`` so revert
        proceeds at the root of the documents tree."""
        fake_session = _FakeSession()
        snapshot = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="x.md",
            folder_id_before=None,
            metadata_before=None,
        )
        # Nothing already lives at /documents/x.md.
        fake_session.execute.return_value = _FakeResult(scalar=None)
        result = await revert_service._reinsert_document_from_revision(
            fake_session,  # type: ignore[arg-type]
            revision=snapshot,
        )
        assert result.status == "ok"

    @pytest.mark.asyncio
    async def test_collision_with_live_doc_returns_tool_unavailable(self) -> None:
        """An existing row at the target path blocks the re-insert."""
        fake_session = _FakeSession()
        snapshot = _doc_revision(
            document_id=None,
            content_before="hi",
            title_before="x.md",
            folder_id_before=None,
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        # The unique_identifier_hash collision SELECT finds an existing row.
        fake_session.execute.return_value = _FakeResult(scalar=42)
        result = await revert_service._reinsert_document_from_revision(
            fake_session,  # type: ignore[arg-type]
            revision=snapshot,
        )
        assert result.status == "tool_unavailable"
        assert "collide" in result.message
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# write_file create revert (DELETE)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestWriteFileCreateRevert:
    """Reverting a ``write_file`` create deletes the document it created."""

    @pytest.mark.asyncio
    async def test_deletes_created_doc(self) -> None:
        """The revert succeeds and awaits ``session.execute`` exactly once."""
        fake_session = _FakeSession()
        snapshot = _doc_revision(
            document_id=99,
            content_before=None,  # marker for "created in this action"
            title_before=None,
        )

        result = await revert_service._delete_created_document(
            fake_session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert result.status == "ok"
        # A single DELETE is the only statement issued.
        assert fake_session.execute.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# rmdir revert (re-INSERT folder)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRmdirRevert:
    """Reverting ``rmdir`` re-inserts the folder from its revision snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_folder_from_snapshot(self) -> None:
        """One ``Folder`` is added and the snapshot is repointed at its id."""
        fake_session = _FakeSession()
        snapshot = _folder_revision(
            folder_id=None,
            name_before="team",
            parent_id_before=None,
            position_before="a0",
        )

        result = await revert_service._reinsert_folder_from_revision(
            fake_session,  # type: ignore[arg-type]
            revision=snapshot,
        )
        from app.db import Folder

        assert result.status == "ok"
        new_folders = [item for item in fake_session.added if isinstance(item, Folder)]
        assert len(new_folders) == 1
        assert new_folders[0].name == "team"
        assert snapshot.folder_id == new_folders[0].id
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# mkdir revert (DELETE folder)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMkdirRevert:
    """Reverting ``mkdir`` deletes the created folder, but only while empty."""

    @pytest.mark.asyncio
    async def test_deletes_empty_folder(self) -> None:
        """With both emptiness checks clear, the folder is deleted."""
        fake_session = _FakeSession()
        snapshot = _folder_revision(folder_id=42)
        # Neither the document check nor the child-folder check finds a row.
        queued_results = [
            _FakeResult(scalar=None),  # docs
            _FakeResult(scalar=None),  # children
            _FakeResult(scalar=None),  # delete (no return value)
        ]
        fake_session.execute.side_effect = queued_results

        result = await revert_service._delete_created_folder(
            fake_session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert result.status == "ok"
        # Three executes in total: docs check, children check, delete.
        assert fake_session.execute.await_count == 3

    @pytest.mark.asyncio
    async def test_reports_tool_unavailable_when_folder_has_children(self) -> None:
        """A non-empty folder is reported as ``tool_unavailable``.

        Note: every ``execute`` here returns a hit, so the first
        (document) check already reports a row.
        """
        fake_session = _FakeSession()
        snapshot = _folder_revision(folder_id=42)
        # First check (docs) returns "row found".
        fake_session.execute.return_value = _FakeResult(scalar=1)

        result = await revert_service._delete_created_folder(
            fake_session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert result.status == "tool_unavailable"
        assert "no longer empty" in result.message
|
||||||
0
surfsense_backend/tests/unit/tasks/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/chat/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/chat/__init__.py
Normal file
|
|
@ -0,0 +1,185 @@
|
||||||
|
"""Unit tests for ``stream_new_chat._extract_chunk_parts``.
|
||||||
|
|
||||||
|
Earlier versions only handled ``isinstance(chunk.content, str)`` and
|
||||||
|
silently dropped every other shape (Anthropic typed-block lists,
|
||||||
|
Bedrock reasoning blocks, ``additional_kwargs.reasoning_content`` from
|
||||||
|
a few providers). These regression tests pin those four shapes plus the
|
||||||
|
defensive cases (``None`` chunk, mixed types, missing fields).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.tasks.chat.stream_new_chat import _extract_chunk_parts
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeChunk:
|
||||||
|
"""Minimal stand-in for ``AIMessageChunk`` used in unit tests."""
|
||||||
|
|
||||||
|
content: Any = ""
|
||||||
|
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||||
|
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStringContent:
    """``content`` given as a plain string."""

    def test_plain_string_content_extracts_as_text(self) -> None:
        """A non-empty string comes back verbatim as ``text``."""
        parts = _extract_chunk_parts(_FakeChunk(content="hello world"))
        assert parts["text"] == "hello world"
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []

    def test_empty_string_content_yields_empty_text(self) -> None:
        """An empty string yields empty values for all three parts."""
        parts = _extract_chunk_parts(_FakeChunk(content=""))
        assert parts["text"] == ""
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestListContent:
    """``content`` given as a list of typed block dicts."""

    def test_list_of_text_blocks_concatenates(self) -> None:
        """Consecutive text blocks are joined in order."""
        blocks = [
            {"type": "text", "text": "Hello "},
            {"type": "text", "text": "world"},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["text"] == "Hello world"
        assert parts["reasoning"] == ""

    def test_mixed_text_and_reasoning_blocks(self) -> None:
        """Reasoning blocks (either key spelling) are kept apart from text."""
        blocks = [
            {"type": "reasoning", "reasoning": "Let me think... "},
            {"type": "reasoning", "text": "still thinking."},
            {"type": "text", "text": "The answer is 42."},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["text"] == "The answer is 42."
        assert parts["reasoning"] == "Let me think... still thinking."

    def test_tool_call_chunks_in_content_list_extracted(self) -> None:
        """A ``tool_call_chunk`` block is surfaced alongside the text."""
        blocks = [
            {"type": "text", "text": "Calling tool..."},
            {
                "type": "tool_call_chunk",
                "id": "call_123",
                "name": "make_widget",
                "args": '{"color":"red"}',
            },
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["text"] == "Calling tool..."
        assert parts["reasoning"] == ""
        assert len(parts["tool_call_chunks"]) == 1
        assert parts["tool_call_chunks"][0]["id"] == "call_123"
        assert parts["tool_call_chunks"][0]["name"] == "make_widget"

    def test_tool_use_blocks_also_extracted(self) -> None:
        """Some providers (Anthropic) emit ``type='tool_use'`` instead."""
        blocks = [
            {
                "type": "tool_use",
                "id": "call_xyz",
                "name": "search",
            },
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["tool_call_chunks"] == [
            {"type": "tool_use", "id": "call_xyz", "name": "search"}
        ]

    def test_unknown_block_types_are_ignored(self) -> None:
        """An unrecognised block type contributes nothing to ``text``."""
        blocks = [
            {"type": "image_url", "url": "https://example.com/x.png"},
            {"type": "text", "text": "ok"},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["text"] == "ok"

    def test_blocks_without_text_field_are_ignored(self) -> None:
        """A text block with no payload key is skipped."""
        blocks = [
            {"type": "text"},  # no text/content key
            {"type": "text", "text": "kept"},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["text"] == "kept"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdditionalKwargsReasoning:
    """Reasoning delivered through ``additional_kwargs``."""

    def test_reasoning_content_in_additional_kwargs(self) -> None:
        """Some providers stash reasoning in ``additional_kwargs.reasoning_content``."""
        message = _FakeChunk(
            content="visible answer",
            additional_kwargs={"reasoning_content": "internal monologue"},
        )
        parts = _extract_chunk_parts(message)
        assert parts["text"] == "visible answer"
        assert parts["reasoning"] == "internal monologue"

    def test_reasoning_appended_to_typed_block_reasoning(self) -> None:
        """Kwargs reasoning is appended after typed-block reasoning."""
        message = _FakeChunk(
            content=[{"type": "reasoning", "text": "from blocks. "}],
            additional_kwargs={"reasoning_content": "from kwargs."},
        )
        parts = _extract_chunk_parts(message)
        assert parts["reasoning"] == "from blocks. from kwargs."
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolCallChunksAttribute:
    """Tool-call chunks supplied via the ``tool_call_chunks`` attribute."""

    def test_tool_call_chunks_attribute_extracted_alongside_string_content(
        self,
    ) -> None:
        """Attribute chunks are returned even when ``content`` is a string."""
        message = _FakeChunk(
            content="streaming text",
            tool_call_chunks=[
                {"name": "save_document", "args": '{"title":"x"}', "id": "tc-9"}
            ],
        )
        parts = _extract_chunk_parts(message)
        assert parts["text"] == "streaming text"
        assert len(parts["tool_call_chunks"]) == 1
        assert parts["tool_call_chunks"][0]["id"] == "tc-9"

    def test_attribute_and_typed_block_chunks_both_collected(self) -> None:
        """Typed-block chunks come first, attribute chunks follow."""
        message = _FakeChunk(
            content=[
                {
                    "type": "tool_call_chunk",
                    "id": "from-block",
                    "name": "x",
                }
            ],
            tool_call_chunks=[{"id": "from-attr", "name": "y"}],
        )
        parts = _extract_chunk_parts(message)
        collected_ids = [entry.get("id") for entry in parts["tool_call_chunks"]]
        assert collected_ids == ["from-block", "from-attr"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefensive:
    """Unusable inputs must produce empty parts."""

    @pytest.mark.parametrize(
        "chunk_value",
        [None, _FakeChunk(content=None), _FakeChunk(content=42)],
    )
    def test_invalid_chunk_returns_empty_parts(self, chunk_value: Any) -> None:
        """``None``, ``content=None`` and non-str/list content all yield empties."""
        parts = _extract_chunk_parts(chunk_value)
        assert parts["text"] == ""
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []
|
||||||
|
|
@ -14,6 +14,13 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms";
|
import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms";
|
||||||
|
import {
|
||||||
|
agentActionsByChatTurnIdAtom,
|
||||||
|
markAgentActionRevertedAtom,
|
||||||
|
resetAgentActionMapAtom,
|
||||||
|
updateAgentActionReversibleAtom,
|
||||||
|
upsertAgentActionAtom,
|
||||||
|
} from "@/atoms/chat/agent-actions.atom";
|
||||||
import {
|
import {
|
||||||
clearTargetCommentIdAtom,
|
clearTargetCommentIdAtom,
|
||||||
currentThreadAtom,
|
currentThreadAtom,
|
||||||
|
|
@ -36,6 +43,11 @@ import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
|
||||||
import { membersAtom } from "@/atoms/members/members-query.atoms";
|
import { membersAtom } from "@/atoms/members/members-query.atoms";
|
||||||
import { removeChatTabAtom, updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom";
|
import { removeChatTabAtom, updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom";
|
||||||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||||
|
import {
|
||||||
|
EditMessageDialog,
|
||||||
|
type EditMessageDialogChoice,
|
||||||
|
} from "@/components/assistant-ui/edit-message-dialog";
|
||||||
|
import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator";
|
||||||
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
|
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
|
||||||
import { Thread } from "@/components/assistant-ui/thread";
|
import { Thread } from "@/components/assistant-ui/thread";
|
||||||
import {
|
import {
|
||||||
|
|
@ -55,14 +67,19 @@ import {
|
||||||
setActivePodcastTaskId,
|
setActivePodcastTaskId,
|
||||||
} from "@/lib/chat/podcast-state";
|
} from "@/lib/chat/podcast-state";
|
||||||
import {
|
import {
|
||||||
|
addStepSeparator,
|
||||||
addToolCall,
|
addToolCall,
|
||||||
|
appendReasoning,
|
||||||
appendText,
|
appendText,
|
||||||
buildContentForPersistence,
|
buildContentForPersistence,
|
||||||
buildContentForUI,
|
buildContentForUI,
|
||||||
type ContentPartsState,
|
type ContentPartsState,
|
||||||
|
endReasoning,
|
||||||
FrameBatchedUpdater,
|
FrameBatchedUpdater,
|
||||||
|
findToolCallIdByLcId,
|
||||||
readSSEStream,
|
readSSEStream,
|
||||||
type ThinkingStepData,
|
type ThinkingStepData,
|
||||||
|
type ToolUIGate,
|
||||||
updateThinkingSteps,
|
updateThinkingSteps,
|
||||||
updateToolCall,
|
updateToolCall,
|
||||||
} from "@/lib/chat/streaming-state";
|
} from "@/lib/chat/streaming-state";
|
||||||
|
|
@ -161,44 +178,38 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tools that should render custom UI in the chat.
|
* Every tool call renders a card. The legacy
|
||||||
|
* ``BASE_TOOLS_WITH_UI`` allowlist used to drop unknown tool calls on the
|
||||||
|
* floor; we now route everything through ``ToolFallback``. Persisted
|
||||||
|
* payload size stays bounded because the backend's
|
||||||
|
* ``format_thinking_step`` summarisation and the
|
||||||
|
* ``result_length``-only default for unknown tools (see
|
||||||
|
* ``stream_new_chat.py``) keep the JSON from ballooning.
|
||||||
*/
|
*/
|
||||||
const BASE_TOOLS_WITH_UI = new Set([
|
const TOOLS_WITH_UI_ALL: ToolUIGate = "all";
|
||||||
"web_search",
|
|
||||||
"generate_podcast",
|
/**
|
||||||
"generate_report",
|
* When a streamed message is persisted, the backend returns the durable
|
||||||
"generate_resume",
|
* ``turn_id`` (``configurable.turn_id`` from the agent run). Merge it
|
||||||
"generate_video_presentation",
|
* into the assistant-ui message metadata so the per-turn "Revert turn"
|
||||||
"display_image",
|
* button can scope to this turn's actions even after a full chat reload.
|
||||||
"generate_image",
|
*/
|
||||||
"delete_notion_page",
|
function mergeChatTurnIdIntoMessage(
|
||||||
"create_notion_page",
|
msg: ThreadMessageLike,
|
||||||
"update_notion_page",
|
turnId: string | null | undefined
|
||||||
"create_linear_issue",
|
): ThreadMessageLike {
|
||||||
"update_linear_issue",
|
if (!turnId) return msg;
|
||||||
"delete_linear_issue",
|
const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> };
|
||||||
"create_google_drive_file",
|
const existingCustom = existingMeta.custom ?? {};
|
||||||
"delete_google_drive_file",
|
if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg;
|
||||||
"create_onedrive_file",
|
return {
|
||||||
"delete_onedrive_file",
|
...msg,
|
||||||
"create_dropbox_file",
|
metadata: {
|
||||||
"delete_dropbox_file",
|
...existingMeta,
|
||||||
"create_calendar_event",
|
custom: { ...existingCustom, chatTurnId: turnId },
|
||||||
"update_calendar_event",
|
},
|
||||||
"delete_calendar_event",
|
};
|
||||||
"create_gmail_draft",
|
}
|
||||||
"update_gmail_draft",
|
|
||||||
"send_gmail_email",
|
|
||||||
"trash_gmail_email",
|
|
||||||
"create_jira_issue",
|
|
||||||
"update_jira_issue",
|
|
||||||
"delete_jira_issue",
|
|
||||||
"create_confluence_page",
|
|
||||||
"update_confluence_page",
|
|
||||||
"delete_confluence_page",
|
|
||||||
"execute",
|
|
||||||
// "write_todos", // Disabled for now
|
|
||||||
]);
|
|
||||||
|
|
||||||
export default function NewChatPage() {
|
export default function NewChatPage() {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
|
|
@ -215,7 +226,7 @@ export default function NewChatPage() {
|
||||||
assistantMsgId: string;
|
assistantMsgId: string;
|
||||||
interruptData: Record<string, unknown>;
|
interruptData: Record<string, unknown>;
|
||||||
} | null>(null);
|
} | null>(null);
|
||||||
const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []);
|
const toolsWithUI = TOOLS_WITH_UI_ALL;
|
||||||
|
|
||||||
// Get disabled tools from the tool toggle UI
|
// Get disabled tools from the tool toggle UI
|
||||||
const disabledTools = useAtomValue(disabledToolsAtom);
|
const disabledTools = useAtomValue(disabledToolsAtom);
|
||||||
|
|
@ -235,6 +246,25 @@ export default function NewChatPage() {
|
||||||
const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom);
|
const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom);
|
||||||
const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom);
|
const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom);
|
||||||
const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom);
|
const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom);
|
||||||
|
// Agent action log SSE side-channel.
|
||||||
|
const upsertAgentAction = useSetAtom(upsertAgentActionAtom);
|
||||||
|
const updateAgentActionReversible = useSetAtom(updateAgentActionReversibleAtom);
|
||||||
|
const markAgentActionReverted = useSetAtom(markAgentActionRevertedAtom);
|
||||||
|
const resetAgentActionMap = useSetAtom(resetAgentActionMapAtom);
|
||||||
|
// Chat-turn-keyed action map for the edit-from-position pre-flight
|
||||||
|
// that decides whether to show the confirmation dialog.
|
||||||
|
const agentActionsByChatTurnId = useAtomValue(agentActionsByChatTurnIdAtom);
|
||||||
|
// Edit dialog state. Holds the message id being edited and
|
||||||
|
// the (already extracted) regenerate args so we can resume the edit
|
||||||
|
// after the user picks "revert all" / "continue" / "cancel".
|
||||||
|
const [editDialogState, setEditDialogState] = useState<{
|
||||||
|
fromMessageId: number;
|
||||||
|
userQuery: string | null;
|
||||||
|
userMessageContent: ThreadMessageLike["content"];
|
||||||
|
userImages: NewChatUserImagePayload[];
|
||||||
|
downstreamReversibleCount: number;
|
||||||
|
downstreamTotalCount: number;
|
||||||
|
} | null>(null);
|
||||||
|
|
||||||
// Get current user for author info in shared chats
|
// Get current user for author info in shared chats
|
||||||
const { data: currentUser } = useAtomValue(currentUserAtom);
|
const { data: currentUser } = useAtomValue(currentUserAtom);
|
||||||
|
|
@ -327,6 +357,7 @@ export default function NewChatPage() {
|
||||||
clearPlanOwnerRegistry();
|
clearPlanOwnerRegistry();
|
||||||
closeReportPanel();
|
closeReportPanel();
|
||||||
closeEditorPanel();
|
closeEditorPanel();
|
||||||
|
resetAgentActionMap();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (urlChatId > 0) {
|
if (urlChatId > 0) {
|
||||||
|
|
@ -395,6 +426,7 @@ export default function NewChatPage() {
|
||||||
removeChatTab,
|
removeChatTab,
|
||||||
searchSpaceId,
|
searchSpaceId,
|
||||||
tokenUsageStore,
|
tokenUsageStore,
|
||||||
|
resetAgentActionMap,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
// Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same)
|
// Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same)
|
||||||
|
|
@ -655,11 +687,14 @@ export default function NewChatPage() {
|
||||||
const contentPartsState: ContentPartsState = {
|
const contentPartsState: ContentPartsState = {
|
||||||
contentParts: [],
|
contentParts: [],
|
||||||
currentTextPartIndex: -1,
|
currentTextPartIndex: -1,
|
||||||
|
currentReasoningPartIndex: -1,
|
||||||
toolCallIndices: new Map(),
|
toolCallIndices: new Map(),
|
||||||
};
|
};
|
||||||
const { contentParts, toolCallIndices } = contentPartsState;
|
const { contentParts, toolCallIndices } = contentPartsState;
|
||||||
let wasInterrupted = false;
|
let wasInterrupted = false;
|
||||||
let tokenUsageData: Record<string, unknown> | null = null;
|
let tokenUsageData: Record<string, unknown> | null = null;
|
||||||
|
// Captured from ``data-turn-info`` at stream start.
|
||||||
|
let streamedChatTurnId: string | null = null;
|
||||||
|
|
||||||
// Add placeholder assistant message
|
// Add placeholder assistant message
|
||||||
setMessages((prev) => [
|
setMessages((prev) => [
|
||||||
|
|
@ -752,21 +787,52 @@ export default function NewChatPage() {
|
||||||
scheduleFlush();
|
scheduleFlush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case "reasoning-delta":
|
||||||
|
appendReasoning(contentPartsState, parsed.delta);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "reasoning-end":
|
||||||
|
endReasoning(contentPartsState);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "start-step":
|
||||||
|
addStepSeparator(contentPartsState);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "finish-step":
|
||||||
|
break;
|
||||||
|
|
||||||
case "tool-input-start":
|
case "tool-input-start":
|
||||||
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
|
addToolCall(
|
||||||
|
contentPartsState,
|
||||||
|
toolsWithUI,
|
||||||
|
parsed.toolCallId,
|
||||||
|
parsed.toolName,
|
||||||
|
{},
|
||||||
|
false,
|
||||||
|
parsed.langchainToolCallId
|
||||||
|
);
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-input-available": {
|
case "tool-input-available": {
|
||||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} });
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
args: parsed.input || {},
|
||||||
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
addToolCall(
|
addToolCall(
|
||||||
contentPartsState,
|
contentPartsState,
|
||||||
toolsWithUI,
|
toolsWithUI,
|
||||||
parsed.toolCallId,
|
parsed.toolCallId,
|
||||||
parsed.toolName,
|
parsed.toolName,
|
||||||
parsed.input || {}
|
parsed.input || {},
|
||||||
|
false,
|
||||||
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
|
|
@ -774,7 +840,10 @@ export default function NewChatPage() {
|
||||||
}
|
}
|
||||||
|
|
||||||
case "tool-output-available": {
|
case "tool-output-available": {
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output });
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
result: parsed.output,
|
||||||
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
|
});
|
||||||
markInterruptsCompleted(contentParts);
|
markInterruptsCompleted(contentParts);
|
||||||
if (parsed.output?.status === "pending" && parsed.output?.podcast_id) {
|
if (parsed.output?.status === "pending" && parsed.output?.podcast_id) {
|
||||||
const idx = toolCallIndices.get(parsed.toolCallId);
|
const idx = toolCallIndices.get(parsed.toolCallId);
|
||||||
|
|
@ -880,6 +949,50 @@ export default function NewChatPage() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "data-action-log": {
|
||||||
|
const al = parsed.data;
|
||||||
|
const matchedToolCallId = al.lc_tool_call_id
|
||||||
|
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id)
|
||||||
|
: null;
|
||||||
|
upsertAgentAction({
|
||||||
|
action: {
|
||||||
|
id: al.id,
|
||||||
|
threadId: currentThreadId,
|
||||||
|
lcToolCallId: al.lc_tool_call_id,
|
||||||
|
chatTurnId: al.chat_turn_id,
|
||||||
|
toolName: al.tool_name,
|
||||||
|
reversible: al.reversible,
|
||||||
|
reverseDescriptorPresent: al.reverse_descriptor_present,
|
||||||
|
error: al.error,
|
||||||
|
revertedByActionId: null,
|
||||||
|
isRevertAction: false,
|
||||||
|
createdAt: al.created_at,
|
||||||
|
},
|
||||||
|
toolCallId: matchedToolCallId,
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "data-action-log-updated": {
|
||||||
|
updateAgentActionReversible({
|
||||||
|
id: parsed.data.id,
|
||||||
|
reversible: parsed.data.reversible,
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "data-turn-info": {
|
||||||
|
streamedChatTurnId = parsed.data.chat_turn_id || null;
|
||||||
|
if (streamedChatTurnId) {
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case "data-token-usage":
|
case "data-token-usage":
|
||||||
tokenUsageData = parsed.data;
|
tokenUsageData = parsed.data;
|
||||||
tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
|
tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
|
||||||
|
|
@ -900,13 +1013,18 @@ export default function NewChatPage() {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: finalContent,
|
content: finalContent,
|
||||||
token_usage: tokenUsageData ?? undefined,
|
token_usage: tokenUsageData ?? undefined,
|
||||||
|
turn_id: streamedChatTurnId,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Update message ID from temporary to database ID so comments work immediately
|
// Update message ID from temporary to database ID so comments work immediately
|
||||||
const newMsgId = `msg-${savedMessage.id}`;
|
const newMsgId = `msg-${savedMessage.id}`;
|
||||||
tokenUsageStore.rename(assistantMsgId, newMsgId);
|
tokenUsageStore.rename(assistantMsgId, newMsgId);
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId
|
||||||
|
? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
|
||||||
|
: m
|
||||||
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
// Update pending interrupt with the new persisted message ID
|
// Update pending interrupt with the new persisted message ID
|
||||||
|
|
@ -929,7 +1047,9 @@ export default function NewChatPage() {
|
||||||
const hasContent = contentParts.some(
|
const hasContent = contentParts.some(
|
||||||
(part) =>
|
(part) =>
|
||||||
(part.type === "text" && part.text.length > 0) ||
|
(part.type === "text" && part.text.length > 0) ||
|
||||||
(part.type === "tool-call" && toolsWithUI.has(part.toolName))
|
(part.type === "reasoning" && part.text.length > 0) ||
|
||||||
|
(part.type === "tool-call" &&
|
||||||
|
(toolsWithUI === "all" || toolsWithUI.has(part.toolName)))
|
||||||
);
|
);
|
||||||
if (hasContent && currentThreadId) {
|
if (hasContent && currentThreadId) {
|
||||||
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
||||||
|
|
@ -937,12 +1057,17 @@ export default function NewChatPage() {
|
||||||
const savedMessage = await appendMessage(currentThreadId, {
|
const savedMessage = await appendMessage(currentThreadId, {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: partialContent,
|
content: partialContent,
|
||||||
|
turn_id: streamedChatTurnId,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Update message ID from temporary to database ID
|
// Update message ID from temporary to database ID
|
||||||
const newMsgId = `msg-${savedMessage.id}`;
|
const newMsgId = `msg-${savedMessage.id}`;
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId
|
||||||
|
? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
|
||||||
|
: m
|
||||||
|
)
|
||||||
);
|
);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error("Failed to persist partial assistant message:", err);
|
console.error("Failed to persist partial assistant message:", err);
|
||||||
|
|
@ -1030,10 +1155,13 @@ export default function NewChatPage() {
|
||||||
const contentPartsState: ContentPartsState = {
|
const contentPartsState: ContentPartsState = {
|
||||||
contentParts: [],
|
contentParts: [],
|
||||||
currentTextPartIndex: -1,
|
currentTextPartIndex: -1,
|
||||||
|
currentReasoningPartIndex: -1,
|
||||||
toolCallIndices: new Map(),
|
toolCallIndices: new Map(),
|
||||||
};
|
};
|
||||||
const { contentParts, toolCallIndices } = contentPartsState;
|
const { contentParts, toolCallIndices } = contentPartsState;
|
||||||
let tokenUsageData: Record<string, unknown> | null = null;
|
let tokenUsageData: Record<string, unknown> | null = null;
|
||||||
|
// Captured from ``data-turn-info`` at stream start.
|
||||||
|
let streamedChatTurnId: string | null = null;
|
||||||
|
|
||||||
const existingMsg = messages.find((m) => m.id === assistantMsgId);
|
const existingMsg = messages.find((m) => m.id === assistantMsgId);
|
||||||
if (existingMsg && Array.isArray(existingMsg.content)) {
|
if (existingMsg && Array.isArray(existingMsg.content)) {
|
||||||
|
|
@ -1136,8 +1264,34 @@ export default function NewChatPage() {
|
||||||
scheduleFlush();
|
scheduleFlush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case "reasoning-delta":
|
||||||
|
appendReasoning(contentPartsState, parsed.delta);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "reasoning-end":
|
||||||
|
endReasoning(contentPartsState);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "start-step":
|
||||||
|
addStepSeparator(contentPartsState);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "finish-step":
|
||||||
|
break;
|
||||||
|
|
||||||
case "tool-input-start":
|
case "tool-input-start":
|
||||||
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
|
addToolCall(
|
||||||
|
contentPartsState,
|
||||||
|
toolsWithUI,
|
||||||
|
parsed.toolCallId,
|
||||||
|
parsed.toolName,
|
||||||
|
{},
|
||||||
|
false,
|
||||||
|
parsed.langchainToolCallId
|
||||||
|
);
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
|
@ -1145,6 +1299,7 @@ export default function NewChatPage() {
|
||||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
args: parsed.input || {},
|
args: parsed.input || {},
|
||||||
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
addToolCall(
|
addToolCall(
|
||||||
|
|
@ -1152,7 +1307,9 @@ export default function NewChatPage() {
|
||||||
toolsWithUI,
|
toolsWithUI,
|
||||||
parsed.toolCallId,
|
parsed.toolCallId,
|
||||||
parsed.toolName,
|
parsed.toolName,
|
||||||
parsed.input || {}
|
parsed.input || {},
|
||||||
|
false,
|
||||||
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
|
|
@ -1161,6 +1318,7 @@ export default function NewChatPage() {
|
||||||
case "tool-output-available":
|
case "tool-output-available":
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
result: parsed.output,
|
result: parsed.output,
|
||||||
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
});
|
});
|
||||||
markInterruptsCompleted(contentParts);
|
markInterruptsCompleted(contentParts);
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
|
|
@ -1222,6 +1380,50 @@ export default function NewChatPage() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "data-action-log": {
|
||||||
|
const al = parsed.data;
|
||||||
|
const matchedToolCallId = al.lc_tool_call_id
|
||||||
|
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id)
|
||||||
|
: null;
|
||||||
|
upsertAgentAction({
|
||||||
|
action: {
|
||||||
|
id: al.id,
|
||||||
|
threadId: resumeThreadId,
|
||||||
|
lcToolCallId: al.lc_tool_call_id,
|
||||||
|
chatTurnId: al.chat_turn_id,
|
||||||
|
toolName: al.tool_name,
|
||||||
|
reversible: al.reversible,
|
||||||
|
reverseDescriptorPresent: al.reverse_descriptor_present,
|
||||||
|
error: al.error,
|
||||||
|
revertedByActionId: null,
|
||||||
|
isRevertAction: false,
|
||||||
|
createdAt: al.created_at,
|
||||||
|
},
|
||||||
|
toolCallId: matchedToolCallId,
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "data-action-log-updated": {
|
||||||
|
updateAgentActionReversible({
|
||||||
|
id: parsed.data.id,
|
||||||
|
reversible: parsed.data.reversible,
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "data-turn-info": {
|
||||||
|
streamedChatTurnId = parsed.data.chat_turn_id || null;
|
||||||
|
if (streamedChatTurnId) {
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case "data-token-usage":
|
case "data-token-usage":
|
||||||
tokenUsageData = parsed.data;
|
tokenUsageData = parsed.data;
|
||||||
tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
|
tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
|
||||||
|
|
@ -1241,11 +1443,16 @@ export default function NewChatPage() {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: finalContent,
|
content: finalContent,
|
||||||
token_usage: tokenUsageData ?? undefined,
|
token_usage: tokenUsageData ?? undefined,
|
||||||
|
turn_id: streamedChatTurnId,
|
||||||
});
|
});
|
||||||
const newMsgId = `msg-${savedMessage.id}`;
|
const newMsgId = `msg-${savedMessage.id}`;
|
||||||
tokenUsageStore.rename(assistantMsgId, newMsgId);
|
tokenUsageStore.rename(assistantMsgId, newMsgId);
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId
|
||||||
|
? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
|
||||||
|
: m
|
||||||
|
)
|
||||||
);
|
);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error("Failed to persist resumed assistant message:", err);
|
console.error("Failed to persist resumed assistant message:", err);
|
||||||
|
|
@ -1340,6 +1547,12 @@ export default function NewChatPage() {
|
||||||
editExtras?: {
|
editExtras?: {
|
||||||
userMessageContent: ThreadMessageLike["content"];
|
userMessageContent: ThreadMessageLike["content"];
|
||||||
userImages: NewChatUserImagePayload[];
|
userImages: NewChatUserImagePayload[];
|
||||||
|
},
|
||||||
|
editFromPosition?: {
|
||||||
|
/** Message id (numeric, parsed from ``msg-<n>``) to rewind to. */
|
||||||
|
fromMessageId?: number | null;
|
||||||
|
/** When true, revert reversible downstream actions before stream. */
|
||||||
|
revertActions?: boolean;
|
||||||
}
|
}
|
||||||
) => {
|
) => {
|
||||||
if (!threadId) {
|
if (!threadId) {
|
||||||
|
|
@ -1384,9 +1597,20 @@ export default function NewChatPage() {
|
||||||
userQueryToDisplay = newUserQuery;
|
userQueryToDisplay = newUserQuery;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the last two messages (user + assistant) from the UI immediately
|
// Remove downstream messages from the UI immediately. The
|
||||||
// The backend will also delete them from the database
|
// backend will also delete them from the database.
|
||||||
|
//
|
||||||
|
// When an explicit ``fromMessageId`` is passed, slice from
|
||||||
|
// that message forward; otherwise fall back to the legacy
|
||||||
|
// "drop the last 2" behaviour.
|
||||||
setMessages((prev) => {
|
setMessages((prev) => {
|
||||||
|
if (editFromPosition?.fromMessageId != null) {
|
||||||
|
const targetId = `msg-${editFromPosition.fromMessageId}`;
|
||||||
|
const sliceIndex = prev.findIndex((m) => m.id === targetId);
|
||||||
|
if (sliceIndex >= 0) {
|
||||||
|
return prev.slice(0, sliceIndex);
|
||||||
|
}
|
||||||
|
}
|
||||||
if (prev.length >= 2) {
|
if (prev.length >= 2) {
|
||||||
return prev.slice(0, -2);
|
return prev.slice(0, -2);
|
||||||
}
|
}
|
||||||
|
|
@ -1406,11 +1630,16 @@ export default function NewChatPage() {
|
||||||
const contentPartsState: ContentPartsState = {
|
const contentPartsState: ContentPartsState = {
|
||||||
contentParts: [],
|
contentParts: [],
|
||||||
currentTextPartIndex: -1,
|
currentTextPartIndex: -1,
|
||||||
|
currentReasoningPartIndex: -1,
|
||||||
toolCallIndices: new Map(),
|
toolCallIndices: new Map(),
|
||||||
};
|
};
|
||||||
const { contentParts, toolCallIndices } = contentPartsState;
|
const { contentParts, toolCallIndices } = contentPartsState;
|
||||||
const batcher = new FrameBatchedUpdater();
|
const batcher = new FrameBatchedUpdater();
|
||||||
let tokenUsageData: Record<string, unknown> | null = null;
|
let tokenUsageData: Record<string, unknown> | null = null;
|
||||||
|
// Captured from ``data-turn-info`` at stream start; stamped
|
||||||
|
// onto persisted messages so future edits can locate the
|
||||||
|
// right LangGraph checkpoint.
|
||||||
|
let streamedChatTurnId: string | null = null;
|
||||||
|
|
||||||
// Add placeholder messages to UI
|
// Add placeholder messages to UI
|
||||||
// Always add back the user message (with new query for edit, or original content for reload)
|
// Always add back the user message (with new query for edit, or original content for reload)
|
||||||
|
|
@ -1449,6 +1678,16 @@ export default function NewChatPage() {
|
||||||
if (isEdit) {
|
if (isEdit) {
|
||||||
requestBody.user_images = editExtras?.userImages ?? [];
|
requestBody.user_images = editExtras?.userImages ?? [];
|
||||||
}
|
}
|
||||||
|
// Explicit edit-from-arbitrary-position. Only send
|
||||||
|
// ``from_message_id`` / ``revert_actions`` when the
|
||||||
|
// caller asked for them; otherwise the backend keeps the
|
||||||
|
// legacy "last 2 messages" behaviour for back-compat.
|
||||||
|
if (editFromPosition?.fromMessageId != null) {
|
||||||
|
requestBody.from_message_id = editFromPosition.fromMessageId;
|
||||||
|
if (editFromPosition.revertActions) {
|
||||||
|
requestBody.revert_actions = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
const response = await fetch(getRegenerateUrl(threadId), {
|
const response = await fetch(getRegenerateUrl(threadId), {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
|
|
@ -1481,28 +1720,62 @@ export default function NewChatPage() {
|
||||||
scheduleFlush();
|
scheduleFlush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case "reasoning-delta":
|
||||||
|
appendReasoning(contentPartsState, parsed.delta);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "reasoning-end":
|
||||||
|
endReasoning(contentPartsState);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "start-step":
|
||||||
|
addStepSeparator(contentPartsState);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "finish-step":
|
||||||
|
break;
|
||||||
|
|
||||||
case "tool-input-start":
|
case "tool-input-start":
|
||||||
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
|
addToolCall(
|
||||||
|
contentPartsState,
|
||||||
|
toolsWithUI,
|
||||||
|
parsed.toolCallId,
|
||||||
|
parsed.toolName,
|
||||||
|
{},
|
||||||
|
false,
|
||||||
|
parsed.langchainToolCallId
|
||||||
|
);
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-input-available":
|
case "tool-input-available":
|
||||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} });
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
args: parsed.input || {},
|
||||||
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
addToolCall(
|
addToolCall(
|
||||||
contentPartsState,
|
contentPartsState,
|
||||||
toolsWithUI,
|
toolsWithUI,
|
||||||
parsed.toolCallId,
|
parsed.toolCallId,
|
||||||
parsed.toolName,
|
parsed.toolName,
|
||||||
parsed.input || {}
|
parsed.input || {},
|
||||||
|
false,
|
||||||
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-output-available":
|
case "tool-output-available":
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output });
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
result: parsed.output,
|
||||||
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
|
});
|
||||||
markInterruptsCompleted(contentParts);
|
markInterruptsCompleted(contentParts);
|
||||||
if (parsed.output?.status === "pending" && parsed.output?.podcast_id) {
|
if (parsed.output?.status === "pending" && parsed.output?.podcast_id) {
|
||||||
const idx = toolCallIndices.get(parsed.toolCallId);
|
const idx = toolCallIndices.get(parsed.toolCallId);
|
||||||
|
|
@ -1528,6 +1801,82 @@ export default function NewChatPage() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "data-action-log": {
|
||||||
|
const al = parsed.data;
|
||||||
|
const matchedToolCallId = al.lc_tool_call_id
|
||||||
|
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id)
|
||||||
|
: null;
|
||||||
|
upsertAgentAction({
|
||||||
|
action: {
|
||||||
|
id: al.id,
|
||||||
|
threadId,
|
||||||
|
lcToolCallId: al.lc_tool_call_id,
|
||||||
|
chatTurnId: al.chat_turn_id,
|
||||||
|
toolName: al.tool_name,
|
||||||
|
reversible: al.reversible,
|
||||||
|
reverseDescriptorPresent: al.reverse_descriptor_present,
|
||||||
|
error: al.error,
|
||||||
|
revertedByActionId: null,
|
||||||
|
isRevertAction: false,
|
||||||
|
createdAt: al.created_at,
|
||||||
|
},
|
||||||
|
toolCallId: matchedToolCallId,
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "data-action-log-updated": {
|
||||||
|
updateAgentActionReversible({
|
||||||
|
id: parsed.data.id,
|
||||||
|
reversible: parsed.data.reversible,
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "data-turn-info": {
|
||||||
|
streamedChatTurnId = parsed.data.chat_turn_id || null;
|
||||||
|
if (streamedChatTurnId) {
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "data-revert-results": {
|
||||||
|
const summary = parsed.data;
|
||||||
|
// failureCount must include every "not undone" bucket
|
||||||
|
// (not_reversible, permission_denied, failed) so the
|
||||||
|
// toast's "X could not be rolled back" math matches
|
||||||
|
// the response invariant ``total === sum(counters)``.
|
||||||
|
// ``skipped`` rows are batch revert artefacts (revert
|
||||||
|
// rows themselves) and are not user-facing failures.
|
||||||
|
const failureCount =
|
||||||
|
summary.failed + summary.not_reversible + (summary.permission_denied ?? 0);
|
||||||
|
if (failureCount > 0) {
|
||||||
|
toast.warning(
|
||||||
|
`Pre-revert: ${summary.reverted}/${summary.total} undone, ${failureCount} could not be rolled back.`
|
||||||
|
);
|
||||||
|
} else if (summary.reverted > 0) {
|
||||||
|
toast.success(
|
||||||
|
summary.reverted === 1
|
||||||
|
? "Reverted 1 downstream action before regenerating."
|
||||||
|
: `Reverted ${summary.reverted} downstream actions before regenerating.`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
for (const r of summary.results) {
|
||||||
|
if (r.status === "reverted" || r.status === "already_reverted") {
|
||||||
|
markAgentActionReverted({
|
||||||
|
id: r.action_id,
|
||||||
|
newActionId: r.new_action_id ?? null,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case "data-token-usage":
|
case "data-token-usage":
|
||||||
tokenUsageData = parsed.data;
|
tokenUsageData = parsed.data;
|
||||||
tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
|
tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
|
||||||
|
|
@ -1552,12 +1901,17 @@ export default function NewChatPage() {
|
||||||
const savedUserMessage = await appendMessage(threadId, {
|
const savedUserMessage = await appendMessage(threadId, {
|
||||||
role: "user",
|
role: "user",
|
||||||
content: userContentToPersist,
|
content: userContentToPersist,
|
||||||
|
turn_id: streamedChatTurnId,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Update user message ID to database ID
|
// Update user message ID to database ID
|
||||||
const newUserMsgId = `msg-${savedUserMessage.id}`;
|
const newUserMsgId = `msg-${savedUserMessage.id}`;
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m))
|
prev.map((m) =>
|
||||||
|
m.id === userMsgId
|
||||||
|
? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id)
|
||||||
|
: m
|
||||||
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
// Persist assistant message
|
// Persist assistant message
|
||||||
|
|
@ -1565,12 +1919,17 @@ export default function NewChatPage() {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: finalContent,
|
content: finalContent,
|
||||||
token_usage: tokenUsageData ?? undefined,
|
token_usage: tokenUsageData ?? undefined,
|
||||||
|
turn_id: streamedChatTurnId,
|
||||||
});
|
});
|
||||||
|
|
||||||
const newMsgId = `msg-${savedMessage.id}`;
|
const newMsgId = `msg-${savedMessage.id}`;
|
||||||
tokenUsageStore.rename(assistantMsgId, newMsgId);
|
tokenUsageStore.rename(assistantMsgId, newMsgId);
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId
|
||||||
|
? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
|
||||||
|
: m
|
||||||
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
trackChatResponseReceived(searchSpaceId, threadId);
|
trackChatResponseReceived(searchSpaceId, threadId);
|
||||||
|
|
@ -1608,7 +1967,14 @@ export default function NewChatPage() {
|
||||||
[threadId, searchSpaceId, messages, disabledTools, tokenUsageStore, toolsWithUI]
|
[threadId, searchSpaceId, messages, disabledTools, tokenUsageStore, toolsWithUI]
|
||||||
);
|
);
|
||||||
|
|
||||||
// Handle editing a message - truncates history and regenerates with new query
|
// Handle editing a message - truncates history and regenerates with new query.
|
||||||
|
//
|
||||||
|
// When ``message.sourceId`` is set (the assistant-ui way to say
|
||||||
|
// "this edit replaces an older message"), we pin
|
||||||
|
// ``from_message_id`` so the backend rewinds to the right LangGraph
|
||||||
|
// checkpoint instead of relying on the legacy "last 2 messages"
|
||||||
|
// rewind. We also count downstream reversible actions and prompt the
|
||||||
|
// user to revert / continue / cancel before regenerating.
|
||||||
const onEdit = useCallback(
|
const onEdit = useCallback(
|
||||||
async (message: AppendMessage) => {
|
async (message: AppendMessage) => {
|
||||||
const { userQuery, userImages } = extractUserTurnForNewChatApi(message, []);
|
const { userQuery, userImages } = extractUserTurnForNewChatApi(message, []);
|
||||||
|
|
@ -1619,9 +1985,95 @@ export default function NewChatPage() {
|
||||||
}
|
}
|
||||||
|
|
||||||
const userMessageContent = message.content as unknown as ThreadMessageLike["content"];
|
const userMessageContent = message.content as unknown as ThreadMessageLike["content"];
|
||||||
await handleRegenerate(queryForApi, { userMessageContent, userImages });
|
|
||||||
|
// ``sourceId`` per @assistant-ui/core's ``AppendMessage`` is
|
||||||
|
// "the ID of the message that was edited". Parse the numeric
|
||||||
|
// suffix so we can map it back to a DB row.
|
||||||
|
const sourceId = (message as { sourceId?: string }).sourceId;
|
||||||
|
const fromMessageId =
|
||||||
|
sourceId && /^msg-\d+$/.test(sourceId)
|
||||||
|
? Number.parseInt(sourceId.replace(/^msg-/, ""), 10)
|
||||||
|
: null;
|
||||||
|
|
||||||
|
if (fromMessageId == null) {
|
||||||
|
// No source id (or non-DB id) — fall back to today's
|
||||||
|
// last-2 behaviour. The user gets the legacy edit flow.
|
||||||
|
await handleRegenerate(queryForApi, { userMessageContent, userImages });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-flight: count reversible downstream actions so we can
|
||||||
|
// auto-skip the dialog for harmless edits.
|
||||||
|
//
|
||||||
|
// "Downstream" means messages AFTER the edited one. The
|
||||||
|
// previous slice ``messages.slice(editedIndex)`` included
|
||||||
|
// the edited message itself in both the total
|
||||||
|
// count and the reversibility scan (any actions on the
|
||||||
|
// edited turn would be double-counted). Slice from
|
||||||
|
// ``editedIndex + 1`` so the dialog text matches reality:
|
||||||
|
// "N downstream messages will be dropped".
|
||||||
|
const editedIndex = messages.findIndex((m) => m.id === `msg-${fromMessageId}`);
|
||||||
|
let downstreamReversibleCount = 0;
|
||||||
|
let downstreamTotalCount = 0;
|
||||||
|
if (editedIndex >= 0) {
|
||||||
|
const downstream = messages.slice(editedIndex + 1);
|
||||||
|
downstreamTotalCount = downstream.length;
|
||||||
|
const seenTurns = new Set<string>();
|
||||||
|
for (const m of downstream) {
|
||||||
|
const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } };
|
||||||
|
const tid = meta.custom?.chatTurnId;
|
||||||
|
if (!tid || seenTurns.has(tid)) continue;
|
||||||
|
seenTurns.add(tid);
|
||||||
|
const turnActions = agentActionsByChatTurnId.get(tid) ?? [];
|
||||||
|
for (const a of turnActions) {
|
||||||
|
if (a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error) {
|
||||||
|
downstreamReversibleCount += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (downstreamReversibleCount === 0) {
|
||||||
|
// Nothing to revert — submit silently.
|
||||||
|
await handleRegenerate(
|
||||||
|
queryForApi,
|
||||||
|
{ userMessageContent, userImages },
|
||||||
|
{ fromMessageId, revertActions: false }
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setEditDialogState({
|
||||||
|
fromMessageId,
|
||||||
|
userQuery: queryForApi,
|
||||||
|
userMessageContent,
|
||||||
|
userImages,
|
||||||
|
downstreamReversibleCount,
|
||||||
|
downstreamTotalCount,
|
||||||
|
});
|
||||||
},
|
},
|
||||||
[handleRegenerate]
|
[handleRegenerate, messages, agentActionsByChatTurnId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleEditDialogChoice = useCallback(
|
||||||
|
async (choice: EditMessageDialogChoice) => {
|
||||||
|
const pending = editDialogState;
|
||||||
|
if (!pending) return;
|
||||||
|
setEditDialogState(null);
|
||||||
|
if (choice === "cancel") return;
|
||||||
|
await handleRegenerate(
|
||||||
|
pending.userQuery,
|
||||||
|
{
|
||||||
|
userMessageContent: pending.userMessageContent,
|
||||||
|
userImages: pending.userImages,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
fromMessageId: pending.fromMessageId,
|
||||||
|
revertActions: choice === "revert",
|
||||||
|
}
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[editDialogState, handleRegenerate]
|
||||||
);
|
);
|
||||||
|
|
||||||
// Handle reloading/refreshing the last AI response
|
// Handle reloading/refreshing the last AI response
|
||||||
|
|
@ -1671,6 +2123,7 @@ export default function NewChatPage() {
|
||||||
<TokenUsageProvider store={tokenUsageStore}>
|
<TokenUsageProvider store={tokenUsageStore}>
|
||||||
<AssistantRuntimeProvider runtime={runtime}>
|
<AssistantRuntimeProvider runtime={runtime}>
|
||||||
<ThinkingStepsDataUI />
|
<ThinkingStepsDataUI />
|
||||||
|
<StepSeparatorDataUI />
|
||||||
<div key={searchSpaceId} className="flex h-full overflow-hidden">
|
<div key={searchSpaceId} className="flex h-full overflow-hidden">
|
||||||
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
|
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
|
||||||
<Thread />
|
<Thread />
|
||||||
|
|
@ -1679,6 +2132,15 @@ export default function NewChatPage() {
|
||||||
<MobileEditorPanel />
|
<MobileEditorPanel />
|
||||||
<MobileHitlEditPanel />
|
<MobileHitlEditPanel />
|
||||||
</div>
|
</div>
|
||||||
|
<EditMessageDialog
|
||||||
|
open={editDialogState !== null}
|
||||||
|
onOpenChange={(open) => {
|
||||||
|
if (!open) setEditDialogState(null);
|
||||||
|
}}
|
||||||
|
downstreamReversibleCount={editDialogState?.downstreamReversibleCount ?? 0}
|
||||||
|
downstreamTotalCount={editDialogState?.downstreamTotalCount ?? 0}
|
||||||
|
onChoose={handleEditDialogChoice}
|
||||||
|
/>
|
||||||
</AssistantRuntimeProvider>
|
</AssistantRuntimeProvider>
|
||||||
</TokenUsageProvider>
|
</TokenUsageProvider>
|
||||||
);
|
);
|
||||||
|
|
|
||||||
194
surfsense_web/atoms/chat/agent-actions.atom.ts
Normal file
194
surfsense_web/atoms/chat/agent-actions.atom.ts
Normal file
|
|
@ -0,0 +1,194 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { atom } from "jotai";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Minimal per-row projection of ``AgentActionLog`` that the tool card
|
||||||
|
* needs to decide whether to render a Revert button.
|
||||||
|
*
|
||||||
|
* Fields are deliberately a subset of the full ``AgentAction`` so the
|
||||||
|
* SSE side-channel (``data-action-log`` / ``data-action-log-updated``)
|
||||||
|
* can populate them without depending on the REST endpoint
|
||||||
|
* ``GET /threads/.../actions`` (which 503s when
|
||||||
|
* ``SURFSENSE_ENABLE_ACTION_LOG`` is off).
|
||||||
|
*/
|
||||||
|
export interface AgentActionLite {
|
||||||
|
// Action-log row id; id-based updaters match rows on this value.
id: number;
|
||||||
|
// Thread the action belongs to (the chat page passes its active thread id).
threadId: number | null;
|
||||||
|
// LangChain ``tool_call.id`` (``lc_tool_call_id`` in the SSE payload);
// used as the key of ``agentActionByLcIdAtom``.
lcToolCallId: string | null;
|
||||||
|
// ``chat_turn_id`` from the SSE payload; used as the key of
// ``agentActionsByChatTurnIdAtom``.
chatTurnId: string | null;
|
||||||
|
// ``tool_name`` from the SSE payload.
toolName: string;
|
||||||
|
// ``reversible`` from the SSE payload; can be flipped later by a
// ``data-action-log-updated`` event.
reversible: boolean;
|
||||||
|
// ``reverse_descriptor_present`` from the SSE payload.
// NOTE(review): presumably "the backend stored enough data to undo
// this action" — confirm against the backend schema.
reverseDescriptorPresent: boolean;
|
||||||
|
// ``error`` from the SSE payload.
// NOTE(review): presumably true when the tool call failed — confirm.
error: boolean;
|
||||||
|
// Always ``null`` when ingested from ``data-action-log``.
// NOTE(review): presumably set to the id of the revert action once
// this action has been reverted — confirm against the revert updater.
revertedByActionId: number | null;
|
||||||
|
// Always ``false`` when ingested from ``data-action-log``.
// NOTE(review): presumably marks rows that are themselves revert
// actions — confirm.
isRevertAction: boolean;
|
||||||
|
// ``created_at`` from the SSE payload.
createdAt: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map keyed off the LangChain ``tool_call.id`` (mirrors ``ContentPart
|
||||||
|
* tool-call.langchainToolCallId``).
|
||||||
|
*/
|
||||||
|
export const agentActionByLcIdAtom = atom<Map<string, AgentActionLite>>(new Map());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parallel map keyed off the synthetic chat-card ``toolCallId``
|
||||||
|
* (``call_<run-id>``) so ``ToolFallback`` (which only receives the
|
||||||
|
* synthetic id from assistant-ui) can join its card to the action log.
|
||||||
|
*
|
||||||
|
* Both maps are kept in sync by ``upsertAgentActionAtom``.
|
||||||
|
*/
|
||||||
|
export const agentActionByToolCallIdAtom = atom<Map<string, AgentActionLite>>(new Map());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Index keyed by ``chat_turn_id`` so the per-turn revert UI can answer
|
||||||
|
* "how many reversible actions does this assistant turn contain?" in
|
||||||
|
* O(1). Each entry's array is ordered by insertion (which
|
||||||
|
* for a single turn matches ``created_at`` because action-log writes
|
||||||
|
* happen synchronously).
|
||||||
|
*/
|
||||||
|
export const agentActionsByChatTurnIdAtom = atom<Map<string, AgentActionLite[]>>(new Map());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Action to upsert one ``AgentActionLite`` row.
|
||||||
|
*
|
||||||
|
* ``toolCallId`` is the synthetic card id (``call_<run-id>`` from
|
||||||
|
* ``stream_new_chat.py``). When provided alongside ``lcToolCallId``, the
|
||||||
|
* action is indexed under BOTH ids so the tool card can perform the
|
||||||
|
* lookup without going via the streaming state.
|
||||||
|
*/
|
||||||
|
/**
 * Write-only atom that upserts one ``AgentActionLite`` row into every
 * index it belongs to.
 *
 * Payload:
 * - ``action``: the row to insert or refresh.
 * - ``toolCallId``: optional synthetic chat-card id. When present the row
 *   is also indexed in ``agentActionByToolCallIdAtom``.
 *
 * Indexing rules:
 * - ``agentActionByLcIdAtom`` is written only when ``action.lcToolCallId``
 *   is truthy.
 * - ``agentActionByToolCallIdAtom`` is written only when ``toolCallId`` is
 *   truthy.
 * - ``agentActionsByChatTurnIdAtom`` is written only when
 *   ``action.chatTurnId`` is truthy. A row that already exists in the
 *   turn's array is replaced IN PLACE so the array keeps its insertion
 *   order; a new row is appended.
 */
export const upsertAgentActionAtom = atom(
	null,
	(_get, set, payload: { action: AgentActionLite; toolCallId?: string | null }) => {
		const { action, toolCallId } = payload;

		// Preserve the local "reverted" bookkeeping if a reversibility
		// flip arrives AFTER the user already reverted via the REST
		// route. We never want a stale ``reversible=true`` event to
		// resurrect a Reverted card.
		const mergeWithPrior = (prior: AgentActionLite | undefined): AgentActionLite => ({
			...action,
			revertedByActionId: prior?.revertedByActionId ?? action.revertedByActionId,
			isRevertAction: prior?.isRevertAction ?? action.isRevertAction,
		});

		// Copy-on-write upsert for the two flat (key -> row) indexes.
		const upsertInto = (
			prev: Map<string, AgentActionLite>,
			key: string
		): Map<string, AgentActionLite> => {
			const next = new Map(prev);
			next.set(key, mergeWithPrior(prev.get(key)));
			return next;
		};

		const lcToolCallId = action.lcToolCallId;
		if (lcToolCallId) {
			set(agentActionByLcIdAtom, (prev) => upsertInto(prev, lcToolCallId));
		}

		if (toolCallId) {
			set(agentActionByToolCallIdAtom, (prev) => upsertInto(prev, toolCallId));
		}

		const turnId = action.chatTurnId;
		if (turnId) {
			set(agentActionsByChatTurnIdAtom, (prev) => {
				const rows = prev.get(turnId) ?? [];
				const position = rows.findIndex((row) => row.id === action.id);
				const merged = mergeWithPrior(position >= 0 ? rows[position] : undefined);
				// Replace an existing row at its current slot instead of
				// moving it to the tail, so the per-turn array stays in
				// insertion order across repeated upserts.
				const updatedRows =
					position >= 0
						? rows.map((row, index) => (index === position ? merged : row))
						: [...rows, merged];
				const next = new Map(prev);
				next.set(turnId, updatedRows);
				return next;
			});
		}
	}
);
|
||||||
|
|
||||||
|
/**
 * Copy-on-write helper for the id-keyed maps: applies ``mutator`` to every
 * entry whose ``id`` equals ``id`` and returns a new ``Map`` holding the
 * results. When no entry matches, the very same ``prev`` instance is
 * returned so the caller's atom value keeps its reference.
 */
function mutateById(
	prev: Map<string, AgentActionLite>,
	id: number,
	mutator: (entry: AgentActionLite) => AgentActionLite
): Map<string, AgentActionLite> {
	// Allocated on the first hit only.
	let patched: Map<string, AgentActionLite> | undefined;
	for (const [key, entry] of prev) {
		if (entry.id !== id) continue;
		if (!patched) {
			patched = new Map(prev);
		}
		patched.set(key, mutator(entry));
	}
	return patched ?? prev;
}
|
||||||
|
|
||||||
|
/**
 * Copy-on-write helper for the per-turn index: in every turn list that
 * contains a row with the given ``id``, that row is replaced by
 * ``mutator(row)``. Lists without a match keep their array reference, and
 * when no list matches at all the same ``prev`` instance is returned.
 */
function mutateByIdInTurnIndex(
	prev: Map<string, AgentActionLite[]>,
	id: number,
	mutator: (entry: AgentActionLite) => AgentActionLite
): Map<string, AgentActionLite[]> {
	// Allocated on the first turn that actually contains the row.
	let patched: Map<string, AgentActionLite[]> | undefined;
	for (const [turnId, rows] of prev) {
		if (!rows.some((row) => row.id === id)) continue;
		if (!patched) {
			patched = new Map(prev);
		}
		patched.set(
			turnId,
			rows.map((row) => (row.id === id ? mutator(row) : row))
		);
	}
	return patched ?? prev;
}
|
||||||
|
|
||||||
|
/**
 * Write-only action that sets the ``reversible`` flag on every stored copy
 * of one action. Rows are matched by the AgentActionLog row id because the
 * SSE ``data-action-log-updated`` payload does NOT carry ``lcToolCallId``.
 * All three indexes (by LangChain id, by card id, by chat turn) are patched.
 */
export const updateAgentActionReversibleAtom = atom(
	null,
	(_get, set, payload: { id: number; reversible: boolean }) => {
		const { id, reversible } = payload;
		const withFlag = (entry: AgentActionLite): AgentActionLite => {
			return { ...entry, reversible };
		};

		set(agentActionByLcIdAtom, (current) => mutateById(current, id, withFlag));
		set(agentActionByToolCallIdAtom, (current) => mutateById(current, id, withFlag));
		set(agentActionsByChatTurnIdAtom, (current) => mutateByIdInTurnIndex(current, id, withFlag));
	}
);
|
||||||
|
|
||||||
|
/**
 * Write-only action that marks an existing entry as reverted (called after
 * a revert request). ``revertedByActionId`` is set to ``newActionId``, or to
 * ``-1`` when the caller passes ``null``. All three indexes are patched.
 */
export const markAgentActionRevertedAtom = atom(
	null,
	(_get, set, payload: { id: number; newActionId: number | null }) => {
		const { id, newActionId } = payload;
		const asReverted = (entry: AgentActionLite): AgentActionLite => {
			return { ...entry, revertedByActionId: newActionId ?? -1 };
		};

		set(agentActionByLcIdAtom, (current) => mutateById(current, id, asReverted));
		set(agentActionByToolCallIdAtom, (current) => mutateById(current, id, asReverted));
		set(agentActionsByChatTurnIdAtom, (current) =>
			mutateByIdInTurnIndex(current, id, asReverted)
		);
	}
);
|
||||||
|
|
||||||
|
/**
 * Batch variant of ``markAgentActionRevertedAtom``: takes a list of
 * ``(id, newActionId)`` pairs and marks each one as reverted, in order.
 */
export const markAgentActionsRevertedBatchAtom = atom(
	null,
	(_get, set, payload: { entries: Array<{ id: number; newActionId: number | null }> }) => {
		const { entries } = payload;
		for (const pair of entries) {
			set(markAgentActionRevertedAtom, pair);
		}
	}
);
|
||||||
|
|
||||||
|
/**
 * Write-only action that empties all three action indexes by replacing each
 * with a fresh ``Map`` (e.g. when the active thread changes).
 */
export const resetAgentActionMapAtom = atom(null, (_get, set) => {
	set(agentActionByLcIdAtom, new Map<string, AgentActionLite>());
	set(agentActionByToolCallIdAtom, new Map<string, AgentActionLite>());
	set(agentActionsByChatTurnIdAtom, new Map<string, AgentActionLite[]>());
});
|
||||||
|
|
@ -33,6 +33,8 @@ import {
|
||||||
useAllCitationMetadata,
|
useAllCitationMetadata,
|
||||||
} from "@/components/assistant-ui/citation-metadata-context";
|
} from "@/components/assistant-ui/citation-metadata-context";
|
||||||
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
|
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
|
||||||
|
import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part";
|
||||||
|
import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button";
|
||||||
import { useTokenUsage } from "@/components/assistant-ui/token-usage-context";
|
import { useTokenUsage } from "@/components/assistant-ui/token-usage-context";
|
||||||
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
|
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
|
||||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||||
|
|
@ -491,6 +493,7 @@ const AssistantMessageInner: FC = () => {
|
||||||
<MessagePrimitive.Parts
|
<MessagePrimitive.Parts
|
||||||
components={{
|
components={{
|
||||||
Text: MarkdownText,
|
Text: MarkdownText,
|
||||||
|
Reasoning: ReasoningMessagePart,
|
||||||
tools: {
|
tools: {
|
||||||
by_name: {
|
by_name: {
|
||||||
generate_report: GenerateReportToolUI,
|
generate_report: GenerateReportToolUI,
|
||||||
|
|
@ -699,6 +702,13 @@ const AssistantActionBar: FC = () => {
|
||||||
const isLast = useAuiState((s) => s.message.isLast);
|
const isLast = useAuiState((s) => s.message.isLast);
|
||||||
const aui = useAui();
|
const aui = useAui();
|
||||||
const api = useElectronAPI();
|
const api = useElectronAPI();
|
||||||
|
// Surface the persisted ``chat_turn_id`` so the per-turn revert
|
||||||
|
// affordance can scope to just this message's actions. Streamed
|
||||||
|
// turns get their id once the assistant message is hydrated/finalised.
|
||||||
|
const chatTurnId = useAuiState(({ message }) => {
|
||||||
|
const meta = message?.metadata as { custom?: { chatTurnId?: string | null } } | undefined;
|
||||||
|
return meta?.custom?.chatTurnId ?? null;
|
||||||
|
});
|
||||||
|
|
||||||
const isQuickAssist = !!api?.replaceText && IS_QUICK_ASSIST_WINDOW;
|
const isQuickAssist = !!api?.replaceText && IS_QUICK_ASSIST_WINDOW;
|
||||||
|
|
||||||
|
|
@ -743,6 +753,9 @@ const AssistantActionBar: FC = () => {
|
||||||
</TooltipIconButton>
|
</TooltipIconButton>
|
||||||
)}
|
)}
|
||||||
<MessageInfoDropdown />
|
<MessageInfoDropdown />
|
||||||
|
<div className="ml-auto">
|
||||||
|
<RevertTurnButton chatTurnId={chatTurnId} />
|
||||||
|
</div>
|
||||||
</ActionBarPrimitive.Root>
|
</ActionBarPrimitive.Root>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
|
||||||
106
surfsense_web/components/assistant-ui/edit-message-dialog.tsx
Normal file
106
surfsense_web/components/assistant-ui/edit-message-dialog.tsx
Normal file
|
|
@ -0,0 +1,106 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Confirmation dialog shown when the user edits a message that has
|
||||||
|
* reversible downstream actions. Three buttons:
|
||||||
|
*
|
||||||
|
* • "Revert all & resubmit" — POST regenerate with revert_actions=true
|
||||||
|
* • "Continue without revert" — POST regenerate with revert_actions=false
|
||||||
|
* • "Cancel" — abort the edit entirely
|
||||||
|
*
|
||||||
|
* The dialog is auto-skipped when zero reversible downstream actions
|
||||||
|
* exist (the caller checks first via ``downstreamReversibleCount``).
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useEffect, useRef, useState } from "react";
|
||||||
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogCancel,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogDescription,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogTitle,
|
||||||
|
} from "@/components/ui/alert-dialog";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
|
||||||
|
/** The three outcomes the dialog can report back through ``onChoose``. */
export type EditMessageDialogChoice = "revert" | "continue" | "cancel";

/** Props accepted by ``EditMessageDialog``. */
export interface EditMessageDialogProps {
	/** Controlled open state, forwarded to the underlying ``AlertDialog``. */
	open: boolean;
	/** Open-state change handler, forwarded to the underlying ``AlertDialog``. */
	onOpenChange: (open: boolean) => void;
	/** Count rendered as the number of actions that can be rolled back. */
	downstreamReversibleCount: number;
	/** Count rendered as the number of downstream messages the edit drops. */
	downstreamTotalCount: number;
	/** Called with the picked choice; buttons stay disabled until it settles. */
	onChoose: (choice: EditMessageDialogChoice) => void | Promise<void>;
}

/**
 * Confirmation dialog offering "revert & resubmit", "continue without
 * reverting" and "cancel". While the chosen ``onChoose`` call is pending,
 * every button is disabled and the picked button shows a progress label.
 */
export function EditMessageDialog({
	open,
	onOpenChange,
	downstreamReversibleCount,
	downstreamTotalCount,
	onChoose,
}: EditMessageDialogProps) {
	// The choice currently in flight, or ``null`` when idle.
	const [busy, setBusy] = useState<EditMessageDialogChoice | null>(null);

	// Tracks whether this component is still mounted. Per the parent's
	// contract, the dialog may be closed (and unmounted) while ``onChoose``
	// is still awaiting a long-running stream; the flag keeps the trailing
	// ``setBusy(null)`` from running against an unmounted component.
	const isMountedRef = useRef(true);
	useEffect(() => {
		isMountedRef.current = true;
		return () => {
			isMountedRef.current = false;
		};
	}, []);

	async function submitChoice(choice: EditMessageDialogChoice) {
		setBusy(choice);
		try {
			await onChoose(choice);
		} finally {
			if (isMountedRef.current) {
				setBusy(null);
			}
		}
	}

	const isBusy = busy !== null;
	const revertLabel =
		busy === "revert"
			? "Reverting & resubmitting…"
			: `Revert ${downstreamReversibleCount} action${
					downstreamReversibleCount === 1 ? "" : "s"
				} & resubmit`;
	const continueLabel = busy === "continue" ? "Resubmitting…" : "Continue without reverting";

	return (
		<AlertDialog open={open} onOpenChange={onOpenChange}>
			<AlertDialogContent>
				<AlertDialogHeader>
					<AlertDialogTitle>Edit this message?</AlertDialogTitle>
					<AlertDialogDescription>
						This edit drops {downstreamTotalCount} downstream message
						{downstreamTotalCount === 1 ? "" : "s"} from the thread. {downstreamReversibleCount}{" "}
						action
						{downstreamReversibleCount === 1 ? "" : "s"} (e.g. file writes, connector changes) can
						be rolled back. Pick how to handle them before regenerating.
					</AlertDialogDescription>
				</AlertDialogHeader>

				<div className="grid gap-2">
					<Button variant="default" disabled={isBusy} onClick={() => submitChoice("revert")}>
						{revertLabel}
					</Button>
					<Button variant="outline" disabled={isBusy} onClick={() => submitChoice("continue")}>
						{continueLabel}
					</Button>
				</div>

				<AlertDialogFooter className="sm:justify-start">
					<AlertDialogCancel disabled={isBusy} onClick={() => submitChoice("cancel")}>
						Cancel
					</AlertDialogCancel>
				</AlertDialogFooter>
			</AlertDialogContent>
		</AlertDialog>
	);
}
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import type { ReasoningMessagePartComponent } from "@assistant-ui/react";
|
||||||
|
import { ChevronRightIcon } from "lucide-react";
|
||||||
|
import { useEffect, useMemo, useState } from "react";
|
||||||
|
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
/**
 * Collapsible renderer for a structured `reasoning` message part (the
 * backend's stream-parity v2 path, A1).
 *
 * Open/closed behaviour:
 *   - starts open only if the part is `running` at first render;
 *   - is forced open whenever the part is `running`;
 *   - is forced closed when the status becomes `complete`;
 *   - the header button toggles it manually at any time.
 *
 * Renders nothing when there is no text and the part is not running.
 * Registered through the `Reasoning` slot of `MessagePrimitive.Parts` in
 * `assistant-message.tsx`.
 */
export const ReasoningMessagePart: ReasoningMessagePartComponent = ({ text, status }) => {
	const statusType = status?.type;
	const isRunning = statusType === "running";
	const [isOpen, setIsOpen] = useState(() => isRunning);

	// Follow the stream: expand while thinking, collapse once finished.
	useEffect(() => {
		if (isRunning) {
			setIsOpen(true);
			return;
		}
		if (statusType === "complete") {
			setIsOpen(false);
		}
	}, [isRunning, statusType]);

	const headerLabel = useMemo(() => {
		if (isRunning) return "Thinking";
		return statusType === "incomplete" ? "Thinking interrupted" : "Thought";
	}, [isRunning, statusType]);

	// Nothing to show yet and nothing on the way: skip the part entirely.
	const hasNoText = !text || text.length === 0;
	if (hasNoText && !isRunning) return null;

	const toggle = () => setIsOpen((prev) => !prev);
	const header = isRunning ? (
		<TextShimmerLoader text={headerLabel} size="sm" />
	) : (
		<span>{headerLabel}</span>
	);

	return (
		<div className="mx-auto w-full max-w-(--thread-max-width) px-2 py-2">
			<div className="rounded-lg">
				<button
					type="button"
					onClick={toggle}
					className={cn(
						"flex w-full items-center gap-1.5 text-left text-sm transition-colors",
						"text-muted-foreground hover:text-foreground"
					)}
				>
					{header}
					<ChevronRightIcon
						className={cn("size-4 transition-transform duration-200", isOpen && "rotate-90")}
					/>
				</button>

				<div
					className={cn(
						"grid transition-[grid-template-rows] duration-300 ease-out",
						isOpen ? "grid-rows-[1fr]" : "grid-rows-[0fr]"
					)}
				>
					<div className="overflow-hidden">
						<div className="mt-2 border-l border-muted-foreground/30 pl-3 text-sm leading-relaxed text-muted-foreground whitespace-pre-wrap wrap-break-word">
							{text}
						</div>
					</div>
				</div>
			</div>
		</div>
	);
};
|
||||||
232
surfsense_web/components/assistant-ui/revert-turn-button.tsx
Normal file
232
surfsense_web/components/assistant-ui/revert-turn-button.tsx
Normal file
|
|
@ -0,0 +1,232 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* "Revert turn" button rendered at the bottom of every completed
|
||||||
|
* assistant turn that has at least one reversible action.
|
||||||
|
*
|
||||||
|
* The button reads the action map keyed by ``chat_turn_id`` from the
|
||||||
|
* SSE side-channel (``data-action-log`` events). It shows a confirmation
|
||||||
|
* dialog summarising "N reversible / M total" and, on confirm, calls
|
||||||
|
* ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
|
||||||
|
*
|
||||||
|
* The route returns a per-action result list and never collapses the
|
||||||
|
* batch into a 4xx — so we render any failed/not_reversible rows inline
|
||||||
|
* with their messages.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useAtomValue, useSetAtom } from "jotai";
|
||||||
|
import { selectAtom } from "jotai/utils";
|
||||||
|
import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react";
|
||||||
|
import { useMemo, useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
import {
|
||||||
|
type AgentActionLite,
|
||||||
|
agentActionsByChatTurnIdAtom,
|
||||||
|
markAgentActionsRevertedBatchAtom,
|
||||||
|
} from "@/atoms/chat/agent-actions.atom";
|
||||||
|
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
||||||
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogAction,
|
||||||
|
AlertDialogCancel,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogDescription,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogTitle,
|
||||||
|
AlertDialogTrigger,
|
||||||
|
} from "@/components/ui/alert-dialog";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import {
|
||||||
|
agentActionsApiService,
|
||||||
|
type RevertTurnActionResult,
|
||||||
|
} from "@/lib/apis/agent-actions-api.service";
|
||||||
|
import { AppError } from "@/lib/error";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
interface RevertTurnButtonProps {
|
||||||
|
chatTurnId: string | null | undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Turn a snake_case tool name into a display label: underscores become
 * spaces and the first character of each word is upper-cased.
 */
function formatToolName(name: string): string {
	const spaced = name.replace(/_/g, " ");
	const titled = spaced.replace(/\b\w/g, (c) => c.toUpperCase());
	return titled;
}
|
||||||
|
|
||||||
|
// Empty-array sentinel so the per-turn ``selectAtom`` slice returns a
|
||||||
|
// stable reference when the turn has no recorded actions yet. Without
|
||||||
|
// this every render allocates a fresh ``[]`` and Jotai's
|
||||||
|
// equality check would re-render the button on unrelated turn updates.
|
||||||
|
const EMPTY_ACTIONS: readonly AgentActionLite[] = Object.freeze([]);
|
||||||
|
|
||||||
|
/**
 * Per-turn "Revert turn" affordance.
 *
 * Renders a trigger button plus confirmation dialog while the turn still
 * has at least one revertible action, and a results dialog listing per-row
 * outcomes after a partially successful revert.
 *
 * Renders nothing when ``chatTurnId`` or the session's thread id is missing,
 * or when there is nothing left to revert AND no results dialog to show.
 * The results dialog is deliberately kept mounted even when the revert just
 * consumed the last reversible action; otherwise a partial outcome (e.g.
 * "2 reverted, 1 not reversible") would open a dialog that never renders.
 */
export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) {
	const session = useAtomValue(chatSessionStateAtom);
	const markRevertedBatch = useSetAtom(markAgentActionsRevertedBatchAtom);
	const [isReverting, setIsReverting] = useState(false);
	const [confirmOpen, setConfirmOpen] = useState(false);
	const [resultsOpen, setResultsOpen] = useState(false);
	const [results, setResults] = useState<RevertTurnActionResult[]>([]);

	// Subscribe ONLY to the slice of the global action map that belongs
	// to ``chatTurnId``. Previously the button read the whole
	// ``agentActionsByChatTurnIdAtom``, which meant every action
	// upsert (one per tool call) re-rendered every Revert button on
	// the page. With ``selectAtom`` we re-render only when our turn's
	// list reference changes — and the upsert/mark atoms produce a
	// fresh list reference for the affected turn only.
	const sliceAtom = useMemo(
		() =>
			selectAtom(
				agentActionsByChatTurnIdAtom,
				(turnIndex) => (chatTurnId ? turnIndex.get(chatTurnId) : undefined) ?? EMPTY_ACTIONS
			),
		[chatTurnId]
	);
	const actions = useAtomValue(sliceAtom);

	// Actions that can still be undone: reversible, not yet reverted, not a
	// revert row themselves, and not errored.
	const reversibleCount = useMemo(
		() =>
			actions.filter(
				(a) => a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error
			).length,
		[actions]
	);
	// Everything the turn did, excluding the revert rows themselves.
	const totalCount = useMemo(() => actions.filter((a) => !a.isRevertAction).length, [actions]);

	if (!chatTurnId) return null;
	const threadId = session?.threadId;
	if (!threadId) return null;

	// The trigger disappears once nothing is left to revert, but the
	// component must stay mounted while the results dialog is open.
	const canRevert = reversibleCount > 0;
	if (!canRevert && !resultsOpen) return null;

	const handleRevertTurn = async () => {
		setIsReverting(true);
		try {
			const response = await agentActionsApiService.revertTurn(threadId, chatTurnId);
			setResults(response.results);
			const revertedEntries = response.results
				.filter((r) => r.status === "reverted" || r.status === "already_reverted")
				.map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null }));
			if (revertedEntries.length > 0) {
				markRevertedBatch({ entries: revertedEntries });
			}
			if (response.status === "ok") {
				toast.success(
					response.reverted === 1 ? "Reverted 1 action." : `Reverted ${response.reverted} actions.`
				);
			} else {
				// Every "not undone" bucket counts as a failure for the
				// user-facing summary. ``skipped`` rows are batch
				// artefacts (revert rows themselves) and intentionally
				// excluded from the failure tally.
				const failureCount =
					response.failed + response.not_reversible + (response.permission_denied ?? 0);
				toast.warning(
					`Reverted ${response.reverted} of ${response.total}. ${failureCount} could not be undone.`
				);
				setResultsOpen(true);
			}
		} catch (err) {
			// A 503 is swallowed without a toast; every other error is surfaced.
			if (err instanceof AppError && err.status === 503) {
				return;
			}
			const message =
				err instanceof AppError
					? err.message
					: err instanceof Error
						? err.message
						: "Failed to revert turn.";
			toast.error(message);
		} finally {
			setIsReverting(false);
			setConfirmOpen(false);
		}
	};

	return (
		<>
			{canRevert && (
				<AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}>
					<AlertDialogTrigger asChild>
						<Button
							size="sm"
							variant="ghost"
							className="text-muted-foreground hover:text-foreground gap-1.5"
							onClick={(e) => {
								e.stopPropagation();
								setConfirmOpen(true);
							}}
						>
							<RotateCcw className="size-3.5" />
							<span>Revert turn</span>
							<span className="text-xs tabular-nums opacity-70">
								{reversibleCount}/{totalCount}
							</span>
						</Button>
					</AlertDialogTrigger>
					<AlertDialogContent>
						<AlertDialogHeader>
							<AlertDialogTitle>Revert this turn?</AlertDialogTitle>
							<AlertDialogDescription>
								This will undo {reversibleCount} of {totalCount} action
								{totalCount === 1 ? "" : "s"} from this turn in reverse order. The chat history and
								any read-only actions are preserved. Some rows may not be reversible — partial success
								is normal.
							</AlertDialogDescription>
						</AlertDialogHeader>
						<AlertDialogFooter>
							<AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel>
							<AlertDialogAction
								onClick={(e) => {
									// Keep the dialog open until the request settles.
									e.preventDefault();
									handleRevertTurn();
								}}
								disabled={isReverting}
							>
								{isReverting ? "Reverting…" : "Revert turn"}
							</AlertDialogAction>
						</AlertDialogFooter>
					</AlertDialogContent>
				</AlertDialog>
			)}

			<AlertDialog open={resultsOpen} onOpenChange={setResultsOpen}>
				<AlertDialogContent>
					<AlertDialogHeader>
						<AlertDialogTitle>Revert results</AlertDialogTitle>
						<AlertDialogDescription>
							Some actions could not be reverted. Review per-row outcomes below.
						</AlertDialogDescription>
					</AlertDialogHeader>
					<ul className="max-h-72 overflow-y-auto space-y-2 text-sm">
						{results.map((r) => (
							<RevertResultRow key={r.action_id} result={r} />
						))}
					</ul>
					<AlertDialogFooter>
						<AlertDialogAction onClick={() => setResultsOpen(false)}>Close</AlertDialogAction>
					</AlertDialogFooter>
				</AlertDialogContent>
			</AlertDialog>
		</>
	);
}
|
||||||
|
|
||||||
|
/**
 * One row of the revert-results list: a status icon, the formatted tool
 * name, the status text (underscores shown as spaces) and, when present,
 * the row's error or message. ``reverted`` and ``already_reverted`` count
 * as success; every other status is shown as a failure.
 */
function RevertResultRow({ result }: { result: RevertTurnActionResult }) {
	const succeeded = result.status === "reverted" || result.status === "already_reverted";
	const StatusIcon = succeeded ? CheckIcon : XCircleIcon;
	const iconTone = succeeded ? "text-emerald-500" : "text-destructive";
	const statusText = result.status.replace(/_/g, " ");

	return (
		<li className="flex items-start gap-2 rounded-md border bg-muted/30 px-3 py-2">
			<StatusIcon className={cn("size-4 mt-0.5 shrink-0", iconTone)} />
			<div className="min-w-0 flex-1">
				<p className="font-medium truncate">
					{formatToolName(result.tool_name)}{" "}
					<span className="ml-1 text-xs text-muted-foreground">{statusText}</span>
				</p>
				{(result.message || result.error) && (
					<p className="text-xs text-muted-foreground mt-0.5">{result.error ?? result.message}</p>
				)}
			</div>
		</li>
	);
}
|
||||||
27
surfsense_web/components/assistant-ui/step-separator.tsx
Normal file
27
surfsense_web/components/assistant-ui/step-separator.tsx
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { makeAssistantDataUI } from "@assistant-ui/react";
|
||||||
|
|
||||||
|
/**
 * Thin horizontal rule drawn between model steps inside one assistant turn.
 *
 * Per the original notes for this component: the data part is pushed by
 * `addStepSeparator` in `streaming-state.ts` when a `start-step` SSE event
 * arrives after the message already has non-step content. The backend
 * currently emits a single `start-step` / `finish-step` pair per turn, so
 * most messages contain no separator; the renderer exists so the planned
 * per-model-step refactor (A2 follow-up) needs no persistence changes.
 */
const StepSeparatorDataRenderer = () => (
	<div className="mx-auto my-3 w-full max-w-(--thread-max-width) px-2">
		<div className="border-t border-border/60" />
	</div>
);

/** Registers the renderer above for data parts named `step-separator`. */
export const StepSeparatorDataUI = makeAssistantDataUI({
	name: "step-separator",
	render: StepSeparatorDataRenderer,
});
|
||||||
|
|
@ -1,12 +1,33 @@
|
||||||
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
|
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
|
||||||
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XCircleIcon } from "lucide-react";
|
import { useAtomValue, useSetAtom } from "jotai";
|
||||||
|
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, RotateCcw, XCircleIcon } from "lucide-react";
|
||||||
import { useMemo, useState } from "react";
|
import { useMemo, useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
import {
|
||||||
|
agentActionByToolCallIdAtom,
|
||||||
|
markAgentActionRevertedAtom,
|
||||||
|
} from "@/atoms/chat/agent-actions.atom";
|
||||||
|
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
||||||
import {
|
import {
|
||||||
DoomLoopApprovalToolUI,
|
DoomLoopApprovalToolUI,
|
||||||
isDoomLoopInterrupt,
|
isDoomLoopInterrupt,
|
||||||
} from "@/components/tool-ui/doom-loop-approval";
|
} from "@/components/tool-ui/doom-loop-approval";
|
||||||
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
|
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
|
||||||
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogAction,
|
||||||
|
AlertDialogCancel,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogDescription,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogTitle,
|
||||||
|
AlertDialogTrigger,
|
||||||
|
} from "@/components/ui/alert-dialog";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
import { getToolIcon } from "@/contracts/enums/toolIcons";
|
import { getToolIcon } from "@/contracts/enums/toolIcons";
|
||||||
|
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
|
||||||
|
import { AppError } from "@/lib/error";
|
||||||
import { isInterruptResult } from "@/lib/hitl";
|
import { isInterruptResult } from "@/lib/hitl";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
|
@ -14,7 +35,99 @@ function formatToolName(name: string): string {
|
||||||
return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
|
return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Inline "Revert" button for a single tool card.
 *
 * Looks up the ``AgentActionLog`` row for the card's synthetic
 * ``toolCallId`` in the SSE side-channel atom, so it works even when
 * ``GET /threads/.../actions`` is gated behind
 * ``SURFSENSE_ENABLE_ACTION_LOG=False`` (503).
 *
 * Renders nothing unless the row exists, is reversible, has not been
 * reverted, is not itself a revert row, carries no error, and the session
 * has a thread id. On confirm it calls the revert endpoint, marks the row
 * as reverted locally and shows a toast.
 */
function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) {
	const session = useAtomValue(chatSessionStateAtom);
	const actionMap = useAtomValue(agentActionByToolCallIdAtom);
	const markReverted = useSetAtom(markAgentActionRevertedAtom);
	const action = actionMap.get(toolCallId);
	const [isReverting, setIsReverting] = useState(false);
	const [confirmOpen, setConfirmOpen] = useState(false);

	// Eligibility guards — any miss hides the button entirely.
	if (
		!action ||
		!action.reversible ||
		action.revertedByActionId !== null ||
		action.isRevertAction ||
		action.error
	) {
		return null;
	}
	const threadId = session?.threadId;
	if (!threadId) return null;

	const handleRevert = async () => {
		setIsReverting(true);
		try {
			const response = await agentActionsApiService.revert(threadId, action.id);
			markReverted({ id: action.id, newActionId: response.new_action_id ?? null });
			toast.success(response.message || "Action reverted.");
		} catch (err) {
			// 503 means revert is gated off on this deployment, so stay
			// silent rather than nagging the user. Any other error is
			// surfaced as a toast so the operator can investigate.
			const gatedOff = err instanceof AppError && err.status === 503;
			if (!gatedOff) {
				let message = "Failed to revert action.";
				if (err instanceof AppError || err instanceof Error) {
					message = err.message;
				}
				toast.error(message);
			}
		} finally {
			setIsReverting(false);
			setConfirmOpen(false);
		}
	};

	return (
		<AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}>
			<AlertDialogTrigger asChild>
				<Button
					size="sm"
					variant="outline"
					className="gap-1.5"
					onClick={(e) => {
						e.stopPropagation();
						setConfirmOpen(true);
					}}
				>
					<RotateCcw className="size-3.5" />
					Revert
				</Button>
			</AlertDialogTrigger>
			<AlertDialogContent>
				<AlertDialogHeader>
					<AlertDialogTitle>Revert this action?</AlertDialogTitle>
					<AlertDialogDescription>
						This will undo <span className="font-medium">{formatToolName(action.toolName)}</span>{" "}
						and append a new audit entry. Chat history is preserved — only the tool's effects on
						your knowledge base or connectors will be reversed where possible.
					</AlertDialogDescription>
				</AlertDialogHeader>
				<AlertDialogFooter>
					<AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel>
					<AlertDialogAction
						onClick={(e) => {
							// Keep the dialog open until the request settles.
							e.preventDefault();
							handleRevert();
						}}
						disabled={isReverting}
					>
						{isReverting ? "Reverting…" : "Revert"}
					</AlertDialogAction>
				</AlertDialogFooter>
			</AlertDialogContent>
		</AlertDialog>
	);
}
|
||||||
|
|
||||||
const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
|
const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
|
||||||
|
toolCallId,
|
||||||
toolName,
|
toolName,
|
||||||
argsText,
|
argsText,
|
||||||
result,
|
result,
|
||||||
|
|
@ -145,6 +258,9 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
<div className="flex justify-end">
|
||||||
|
<ToolCardRevertButton toolCallId={toolCallId} />
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import {
|
||||||
import { Turnstile, type TurnstileInstance } from "@marsidev/react-turnstile";
|
import { Turnstile, type TurnstileInstance } from "@marsidev/react-turnstile";
|
||||||
import { ShieldCheck } from "lucide-react";
|
import { ShieldCheck } from "lucide-react";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator";
|
||||||
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
|
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
|
||||||
import {
|
import {
|
||||||
createTokenUsageStore,
|
createTokenUsageStore,
|
||||||
|
|
@ -17,10 +18,13 @@ import {
|
||||||
} from "@/components/assistant-ui/token-usage-context";
|
} from "@/components/assistant-ui/token-usage-context";
|
||||||
import { useAnonymousMode } from "@/contexts/anonymous-mode";
|
import { useAnonymousMode } from "@/contexts/anonymous-mode";
|
||||||
import {
|
import {
|
||||||
|
addStepSeparator,
|
||||||
addToolCall,
|
addToolCall,
|
||||||
|
appendReasoning,
|
||||||
appendText,
|
appendText,
|
||||||
buildContentForUI,
|
buildContentForUI,
|
||||||
type ContentPartsState,
|
type ContentPartsState,
|
||||||
|
endReasoning,
|
||||||
FrameBatchedUpdater,
|
FrameBatchedUpdater,
|
||||||
readSSEStream,
|
readSSEStream,
|
||||||
type ThinkingStepData,
|
type ThinkingStepData,
|
||||||
|
|
@ -32,7 +36,9 @@ import { trackAnonymousChatMessageSent } from "@/lib/posthog/events";
|
||||||
import { FreeModelSelector } from "./free-model-selector";
|
import { FreeModelSelector } from "./free-model-selector";
|
||||||
import { FreeThread } from "./free-thread";
|
import { FreeThread } from "./free-thread";
|
||||||
|
|
||||||
const TOOLS_WITH_UI = new Set(["web_search", "document_qna"]);
|
// Render all tool calls via ToolFallback; backend keeps persisted
|
||||||
|
// payloads bounded by summarising / truncating outputs.
|
||||||
|
const TOOLS_WITH_UI = "all" as const;
|
||||||
const TURNSTILE_SITE_KEY = process.env.NEXT_PUBLIC_TURNSTILE_SITE_KEY ?? "";
|
const TURNSTILE_SITE_KEY = process.env.NEXT_PUBLIC_TURNSTILE_SITE_KEY ?? "";
|
||||||
|
|
||||||
/** Try to parse a CAPTCHA_REQUIRED or CAPTCHA_INVALID code from a non-ok response. */
|
/** Try to parse a CAPTCHA_REQUIRED or CAPTCHA_INVALID code from a non-ok response. */
|
||||||
|
|
@ -125,6 +131,7 @@ export function FreeChatPage() {
|
||||||
const contentPartsState: ContentPartsState = {
|
const contentPartsState: ContentPartsState = {
|
||||||
contentParts: [],
|
contentParts: [],
|
||||||
currentTextPartIndex: -1,
|
currentTextPartIndex: -1,
|
||||||
|
currentReasoningPartIndex: -1,
|
||||||
toolCallIndices: new Map(),
|
toolCallIndices: new Map(),
|
||||||
};
|
};
|
||||||
const { toolCallIndices } = contentPartsState;
|
const { toolCallIndices } = contentPartsState;
|
||||||
|
|
@ -148,28 +155,62 @@ export function FreeChatPage() {
|
||||||
scheduleFlush();
|
scheduleFlush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case "reasoning-delta":
|
||||||
|
appendReasoning(contentPartsState, parsed.delta);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "reasoning-end":
|
||||||
|
endReasoning(contentPartsState);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "start-step":
|
||||||
|
addStepSeparator(contentPartsState);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "finish-step":
|
||||||
|
break;
|
||||||
|
|
||||||
case "tool-input-start":
|
case "tool-input-start":
|
||||||
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
|
addToolCall(
|
||||||
|
contentPartsState,
|
||||||
|
TOOLS_WITH_UI,
|
||||||
|
parsed.toolCallId,
|
||||||
|
parsed.toolName,
|
||||||
|
{},
|
||||||
|
false,
|
||||||
|
parsed.langchainToolCallId
|
||||||
|
);
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-input-available":
|
case "tool-input-available":
|
||||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} });
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
args: parsed.input || {},
|
||||||
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
addToolCall(
|
addToolCall(
|
||||||
contentPartsState,
|
contentPartsState,
|
||||||
TOOLS_WITH_UI,
|
TOOLS_WITH_UI,
|
||||||
parsed.toolCallId,
|
parsed.toolCallId,
|
||||||
parsed.toolName,
|
parsed.toolName,
|
||||||
parsed.input || {}
|
parsed.input || {},
|
||||||
|
false,
|
||||||
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-output-available":
|
case "tool-output-available":
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output });
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
result: parsed.output,
|
||||||
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
|
});
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
|
@ -369,6 +410,7 @@ export function FreeChatPage() {
|
||||||
<TokenUsageProvider store={tokenUsageStore}>
|
<TokenUsageProvider store={tokenUsageStore}>
|
||||||
<AssistantRuntimeProvider runtime={runtime}>
|
<AssistantRuntimeProvider runtime={runtime}>
|
||||||
<ThinkingStepsDataUI />
|
<ThinkingStepsDataUI />
|
||||||
|
<StepSeparatorDataUI />
|
||||||
<div className="flex h-full flex-col overflow-hidden">
|
<div className="flex h-full flex-col overflow-hidden">
|
||||||
<div className="flex h-14 shrink-0 items-center justify-between border-b border-border/40 px-4">
|
<div className="flex h-14 shrink-0 items-center justify-between border-b border-border/40 px-4">
|
||||||
<FreeModelSelector />
|
<FreeModelSelector />
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { AssistantRuntimeProvider } from "@assistant-ui/react";
|
import { AssistantRuntimeProvider } from "@assistant-ui/react";
|
||||||
|
import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator";
|
||||||
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
|
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
|
||||||
import { Navbar } from "@/components/homepage/navbar";
|
import { Navbar } from "@/components/homepage/navbar";
|
||||||
import { ReportPanel } from "@/components/report-panel/report-panel";
|
import { ReportPanel } from "@/components/report-panel/report-panel";
|
||||||
|
|
@ -41,6 +42,7 @@ export function PublicChatView({ shareToken }: PublicChatViewProps) {
|
||||||
<Navbar scrolledBgClassName={navbarScrolledBg} />
|
<Navbar scrolledBgClassName={navbarScrolledBg} />
|
||||||
<AssistantRuntimeProvider runtime={runtime}>
|
<AssistantRuntimeProvider runtime={runtime}>
|
||||||
<ThinkingStepsDataUI />
|
<ThinkingStepsDataUI />
|
||||||
|
<StepSeparatorDataUI />
|
||||||
<div className="flex h-screen pt-16 overflow-hidden">
|
<div className="flex h-screen pt-16 overflow-hidden">
|
||||||
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
|
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
|
||||||
<PublicThread footer={<PublicChatFooter shareToken={shareToken} />} />
|
<PublicThread footer={<PublicChatFooter shareToken={shareToken} />} />
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ import Image from "next/image";
|
||||||
import { type FC, type ReactNode, useState } from "react";
|
import { type FC, type ReactNode, useState } from "react";
|
||||||
import { CitationMetadataProvider } from "@/components/assistant-ui/citation-metadata-context";
|
import { CitationMetadataProvider } from "@/components/assistant-ui/citation-metadata-context";
|
||||||
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
|
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
|
||||||
|
import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part";
|
||||||
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
|
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
|
||||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||||
import { GenerateImageToolUI } from "@/components/tool-ui/generate-image";
|
import { GenerateImageToolUI } from "@/components/tool-ui/generate-image";
|
||||||
|
|
@ -157,6 +158,7 @@ const PublicAssistantMessage: FC = () => {
|
||||||
<MessagePrimitive.Parts
|
<MessagePrimitive.Parts
|
||||||
components={{
|
components={{
|
||||||
Text: MarkdownText,
|
Text: MarkdownText,
|
||||||
|
Reasoning: ReasoningMessagePart,
|
||||||
tools: {
|
tools: {
|
||||||
by_name: {
|
by_name: {
|
||||||
generate_podcast: GeneratePodcastToolUI,
|
generate_podcast: GeneratePodcastToolUI,
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,112 @@
|
||||||
import {
|
import {
|
||||||
BookOpen,
|
BookOpen,
|
||||||
Brain,
|
Brain,
|
||||||
|
Calendar,
|
||||||
|
Check,
|
||||||
|
FileEdit,
|
||||||
|
FilePlus,
|
||||||
FileText,
|
FileText,
|
||||||
FileUser,
|
FileUser,
|
||||||
|
FileX,
|
||||||
Film,
|
Film,
|
||||||
|
FolderPlus,
|
||||||
|
FolderTree,
|
||||||
|
FolderX,
|
||||||
Globe,
|
Globe,
|
||||||
ImageIcon,
|
ImageIcon,
|
||||||
|
ListTodo,
|
||||||
type LucideIcon,
|
type LucideIcon,
|
||||||
|
Mail,
|
||||||
|
MessagesSquare,
|
||||||
|
Move,
|
||||||
|
Plus,
|
||||||
Podcast,
|
Podcast,
|
||||||
ScanLine,
|
ScanLine,
|
||||||
|
Search,
|
||||||
|
Send,
|
||||||
|
Trash2,
|
||||||
Wrench,
|
Wrench,
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Every tool now renders a card via ``ToolFallback``. The icon map is
|
||||||
|
* keyed on the canonical backend tool name (registered in
|
||||||
|
* ``surfsense_backend/app/agents/new_chat/tools/registry.py``); unknown
|
||||||
|
* names fall back to the generic ``Wrench`` icon so the card still
|
||||||
|
* communicates "this is a tool call".
|
||||||
|
*/
|
||||||
const TOOL_ICONS: Record<string, LucideIcon> = {
|
const TOOL_ICONS: Record<string, LucideIcon> = {
|
||||||
|
// Generators
|
||||||
generate_podcast: Podcast,
|
generate_podcast: Podcast,
|
||||||
generate_video_presentation: Film,
|
generate_video_presentation: Film,
|
||||||
generate_report: FileText,
|
generate_report: FileText,
|
||||||
generate_resume: FileUser,
|
generate_resume: FileUser,
|
||||||
generate_image: ImageIcon,
|
generate_image: ImageIcon,
|
||||||
|
display_image: ImageIcon,
|
||||||
|
// Web / search
|
||||||
scrape_webpage: ScanLine,
|
scrape_webpage: ScanLine,
|
||||||
web_search: Globe,
|
web_search: Globe,
|
||||||
search_surfsense_docs: BookOpen,
|
search_surfsense_docs: BookOpen,
|
||||||
|
// Memory
|
||||||
update_memory: Brain,
|
update_memory: Brain,
|
||||||
|
// Filesystem (built-in deepagent + middleware)
|
||||||
|
read_file: FileText,
|
||||||
|
write_file: FilePlus,
|
||||||
|
edit_file: FileEdit,
|
||||||
|
move_file: Move,
|
||||||
|
rm: FileX,
|
||||||
|
rmdir: FolderX,
|
||||||
|
mkdir: FolderPlus,
|
||||||
|
ls: FolderTree,
|
||||||
|
write_todos: ListTodo,
|
||||||
|
// Calendar
|
||||||
|
search_calendar_events: Search,
|
||||||
|
create_calendar_event: Calendar,
|
||||||
|
update_calendar_event: Calendar,
|
||||||
|
delete_calendar_event: Calendar,
|
||||||
|
// Gmail
|
||||||
|
search_gmail: Search,
|
||||||
|
read_gmail_email: Mail,
|
||||||
|
create_gmail_draft: Mail,
|
||||||
|
update_gmail_draft: FileEdit,
|
||||||
|
send_gmail_email: Send,
|
||||||
|
trash_gmail_email: Trash2,
|
||||||
|
// Notion / Confluence pages
|
||||||
|
create_notion_page: FilePlus,
|
||||||
|
update_notion_page: FileEdit,
|
||||||
|
delete_notion_page: FileX,
|
||||||
|
create_confluence_page: FilePlus,
|
||||||
|
update_confluence_page: FileEdit,
|
||||||
|
delete_confluence_page: FileX,
|
||||||
|
// Linear / Jira issues
|
||||||
|
create_linear_issue: Plus,
|
||||||
|
update_linear_issue: FileEdit,
|
||||||
|
delete_linear_issue: Trash2,
|
||||||
|
create_jira_issue: Plus,
|
||||||
|
update_jira_issue: FileEdit,
|
||||||
|
delete_jira_issue: Trash2,
|
||||||
|
// Drive-like file connectors
|
||||||
|
create_google_drive_file: FilePlus,
|
||||||
|
delete_google_drive_file: FileX,
|
||||||
|
create_dropbox_file: FilePlus,
|
||||||
|
delete_dropbox_file: FileX,
|
||||||
|
create_onedrive_file: FilePlus,
|
||||||
|
delete_onedrive_file: FileX,
|
||||||
|
// Chat connectors
|
||||||
|
list_discord_channels: MessagesSquare,
|
||||||
|
read_discord_messages: MessagesSquare,
|
||||||
|
send_discord_message: Send,
|
||||||
|
list_teams_channels: MessagesSquare,
|
||||||
|
read_teams_messages: MessagesSquare,
|
||||||
|
send_teams_message: Send,
|
||||||
|
// Luma
|
||||||
|
list_luma_events: Calendar,
|
||||||
|
read_luma_event: Calendar,
|
||||||
|
create_luma_event: Calendar,
|
||||||
|
// Misc
|
||||||
|
get_connected_accounts: Check,
|
||||||
|
execute: Wrench,
|
||||||
|
execute_code: Wrench,
|
||||||
};
|
};
|
||||||
|
|
||||||
export function getToolIcon(name: string): LucideIcon {
|
export function getToolIcon(name: string): LucideIcon {
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,12 @@ const AgentActionReadSchema = z.object({
|
||||||
reverse_of: z.number().nullable(),
|
reverse_of: z.number().nullable(),
|
||||||
reverted_by_action_id: z.number().nullable(),
|
reverted_by_action_id: z.number().nullable(),
|
||||||
is_revert_action: z.boolean(),
|
is_revert_action: z.boolean(),
|
||||||
|
// Correlation ids added in migration 135. The LangChain
|
||||||
|
// ``tool_call_id`` joins this row to the chat tool card via the
|
||||||
|
// ``data-action-log.lc_tool_call_id`` SSE event, and
|
||||||
|
// ``chat_turn_id`` keys the per-turn revert endpoint.
|
||||||
|
tool_call_id: z.string().nullable().optional(),
|
||||||
|
chat_turn_id: z.string().nullable().optional(),
|
||||||
created_at: z.string(),
|
created_at: z.string(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -38,6 +44,48 @@ const RevertResponseSchema = z.object({
|
||||||
|
|
||||||
export type RevertResponse = z.infer<typeof RevertResponseSchema>;
|
export type RevertResponse = z.infer<typeof RevertResponseSchema>;
|
||||||
|
|
||||||
|
// Per-turn batch revert. The route never returns whole-batch 4xx;
|
||||||
|
// partial success is the common case and surfaced as
|
||||||
|
// ``status === "partial"`` with a per-action result list.
|
||||||
|
// One entry of the per-action result list returned by the batch
// "revert turn" endpoint (see ``RevertTurnResponseSchema.results``).
const RevertTurnActionResultSchema = z.object({
	action_id: z.number(),
	tool_name: z.string(),
	// Outcome for this single action; parsing fails on any other literal.
	status: z.enum([
		"reverted",
		"already_reverted",
		"not_reversible",
		"permission_denied",
		"failed",
		"skipped",
	]),
	// The three fields below may be absent or null in the payload.
	message: z.string().nullable().optional(),
	new_action_id: z.number().nullable().optional(),
	error: z.string().nullable().optional(),
});
|
||||||
|
|
||||||
|
export type RevertTurnActionResult = z.infer<typeof RevertTurnActionResultSchema>;
|
||||||
|
|
||||||
|
// Response envelope of the batch "revert turn" endpoint: an overall
// status, one counter per outcome, and the per-action result list.
const RevertTurnResponseSchema = z.object({
	status: z.enum(["ok", "partial"]),
	chat_turn_id: z.string(),
	total: z.number(),
	reverted: z.number(),
	already_reverted: z.number(),
	not_reversible: z.number(),
	// ``permission_denied`` and ``skipped`` are first-class counters so
	// ``total === reverted + already_reverted +
	// not_reversible + permission_denied + failed + skipped`` always
	// holds. ``.default(0)`` keeps the schema backwards-compatible
	// with older deployments that haven't shipped the response model
	// update yet.
	permission_denied: z.number().default(0),
	failed: z.number(),
	skipped: z.number().default(0),
	results: z.array(RevertTurnActionResultSchema),
});
|
||||||
|
|
||||||
|
export type RevertTurnResponse = z.infer<typeof RevertTurnResponseSchema>;
|
||||||
|
|
||||||
class AgentActionsApiService {
|
class AgentActionsApiService {
|
||||||
listForThread = async (
|
listForThread = async (
|
||||||
threadId: number,
|
threadId: number,
|
||||||
|
|
@ -59,6 +107,14 @@ class AgentActionsApiService {
|
||||||
{ body: {} }
|
{ body: {} }
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
 * Revert every action recorded for one chat turn of a thread.
 *
 * POSTs an empty JSON body to the per-turn revert route and validates
 * the reply against ``RevertTurnResponseSchema``. The turn id is
 * URL-encoded before being placed in the path.
 */
revertTurn = async (threadId: number, chatTurnId: string): Promise<RevertTurnResponse> => {
	const encodedTurnId = encodeURIComponent(chatTurnId);
	const endpoint = `/api/v1/threads/${threadId}/revert-turn/${encodedTurnId}`;
	const requestOptions = { body: {} };
	return baseApiService.post(endpoint, RevertTurnResponseSchema, requestOptions);
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export const agentActionsApiService = new AgentActionsApiService();
|
export const agentActionsApiService = new AgentActionsApiService();
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
|
||||||
}
|
}
|
||||||
|
|
||||||
const metadata =
|
const metadata =
|
||||||
msg.author_id || msg.token_usage
|
msg.author_id || msg.token_usage || msg.turn_id
|
||||||
? {
|
? {
|
||||||
custom: {
|
custom: {
|
||||||
...(msg.author_id && {
|
...(msg.author_id && {
|
||||||
|
|
@ -50,6 +50,10 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
...(msg.token_usage && { usage: msg.token_usage }),
|
...(msg.token_usage && { usage: msg.token_usage }),
|
||||||
|
// Surface ``chat_turn_id`` so the assistant message
|
||||||
|
// footer can scope its "Revert turn" button to just
|
||||||
|
// this turn's actions. Null on legacy rows.
|
||||||
|
...(msg.turn_id && { chatTurnId: msg.turn_id }),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
: undefined;
|
: undefined;
|
||||||
|
|
|
||||||
|
|
@ -9,21 +9,42 @@ export interface ThinkingStepData {
|
||||||
|
|
||||||
export type ContentPart =
|
export type ContentPart =
|
||||||
| { type: "text"; text: string }
|
| { type: "text"; text: string }
|
||||||
|
| { type: "reasoning"; text: string }
|
||||||
| {
|
| {
|
||||||
type: "tool-call";
|
type: "tool-call";
|
||||||
toolCallId: string;
|
toolCallId: string;
|
||||||
toolName: string;
|
toolName: string;
|
||||||
args: Record<string, unknown>;
|
args: Record<string, unknown>;
|
||||||
result?: unknown;
|
result?: unknown;
|
||||||
|
/**
|
||||||
|
* Authoritative LangChain ``tool_call.id`` propagated by the backend
|
||||||
|
* via ``langchainToolCallId`` on tool-input-start/available and
|
||||||
|
* tool-output-available events. Used to join a card to the
|
||||||
|
* matching ``AgentActionLog`` row exposed by
|
||||||
|
* ``GET /threads/{id}/actions`` and the streamed
|
||||||
|
* ``data-action-log`` events.
|
||||||
|
*/
|
||||||
|
langchainToolCallId?: string;
|
||||||
}
|
}
|
||||||
| {
|
| {
|
||||||
type: "data-thinking-steps";
|
type: "data-thinking-steps";
|
||||||
data: { steps: ThinkingStepData[] };
|
data: { steps: ThinkingStepData[] };
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
/**
|
||||||
|
* Between-step separator. Pushed by `addStepSeparator` when
|
||||||
|
* a `start-step` SSE event arrives AFTER the message already
|
||||||
|
* has non-step content. Rendered by `StepSeparatorDataUI`
|
||||||
|
* (see assistant-ui/step-separator.tsx).
|
||||||
|
*/
|
||||||
|
type: "data-step-separator";
|
||||||
|
data: { stepIndex: number };
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
 * Mutable accumulator for the parts of one streaming assistant message.
 * The helpers in this module (``appendText``, ``appendReasoning``,
 * ``addToolCall``, ...) mutate it in place.
 */
export interface ContentPartsState {
	// Parts in the order they were pushed.
	contentParts: ContentPart[];
	// Index into ``contentParts`` of the text part currently being
	// appended to, or -1 when no text part is open.
	currentTextPartIndex: number;
	// Same as above for the open reasoning part; -1 when none is open.
	currentReasoningPartIndex: number;
	// toolCallId -> index of that tool-call part in ``contentParts``.
	toolCallIndices: Map<string, number>;
}
|
||||||
|
|
||||||
|
|
@ -74,6 +95,9 @@ export function updateThinkingSteps(
|
||||||
if (state.currentTextPartIndex >= 0) {
|
if (state.currentTextPartIndex >= 0) {
|
||||||
state.currentTextPartIndex += 1;
|
state.currentTextPartIndex += 1;
|
||||||
}
|
}
|
||||||
|
if (state.currentReasoningPartIndex >= 0) {
|
||||||
|
state.currentReasoningPartIndex += 1;
|
||||||
|
}
|
||||||
for (const [id, idx] of state.toolCallIndices) {
|
for (const [id, idx] of state.toolCallIndices) {
|
||||||
state.toolCallIndices.set(id, idx + 1);
|
state.toolCallIndices.set(id, idx + 1);
|
||||||
}
|
}
|
||||||
|
|
@ -131,6 +155,12 @@ export class FrameBatchedUpdater {
|
||||||
}
|
}
|
||||||
|
|
||||||
export function appendText(state: ContentPartsState, delta: string): void {
|
export function appendText(state: ContentPartsState, delta: string): void {
|
||||||
|
// First text delta after a reasoning block: close the reasoning so
|
||||||
|
// the assistant-ui renderer treats them as separate parts (the
|
||||||
|
// reasoning block collapses; the answer streams below).
|
||||||
|
if (state.currentReasoningPartIndex >= 0) {
|
||||||
|
state.currentReasoningPartIndex = -1;
|
||||||
|
}
|
||||||
if (
|
if (
|
||||||
state.currentTextPartIndex >= 0 &&
|
state.currentTextPartIndex >= 0 &&
|
||||||
state.contentParts[state.currentTextPartIndex]?.type === "text"
|
state.contentParts[state.currentTextPartIndex]?.type === "text"
|
||||||
|
|
@ -143,36 +173,129 @@ export function appendText(state: ContentPartsState, delta: string): void {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Append a reasoning delta to the message being streamed.
 *
 * Mirror image of ``appendText``: any open text part is closed first,
 * then the delta is added to the open reasoning part, or a new
 * reasoning part is started when none is open. ``endReasoning`` closes
 * the open part, so a later delta starts a fresh one.
 */
export function appendReasoning(state: ContentPartsState, delta: string): void {
	if (state.currentTextPartIndex >= 0) {
		state.currentTextPartIndex = -1;
	}

	const openIndex = state.currentReasoningPartIndex;
	const openPart = openIndex >= 0 ? state.contentParts[openIndex] : undefined;

	if (openPart && openPart.type === "reasoning") {
		openPart.text += delta;
		return;
	}

	state.contentParts.push({ type: "reasoning", text: delta });
	state.currentReasoningPartIndex = state.contentParts.length - 1;
}
|
||||||
|
|
||||||
|
/**
 * Close the currently open reasoning part, if any. The part itself is
 * left in ``contentParts``; only the "open" pointer is reset, so the
 * next ``appendReasoning`` call starts a new reasoning part.
 */
export function endReasoning(state: ContentPartsState): void {
	state.currentReasoningPartIndex = -1;
}
|
||||||
|
|
||||||
|
/**
 * Insert a divider between consecutive model steps of one assistant turn.
 *
 * Nothing is pushed when the message has no text / reasoning / tool-call
 * part yet (so the first step gets no leading divider) or when the last
 * part is already a separator (guards against duplicate ``start-step``
 * events). ``stepIndex`` is the number of separators already present.
 * Pushing a separator also closes any open text and reasoning part.
 */
export function addStepSeparator(state: ContentPartsState): void {
	const parts = state.contentParts;

	let sawContent = false;
	let separatorCount = 0;
	for (const part of parts) {
		if (part.type === "data-step-separator") {
			separatorCount += 1;
		} else if (part.type === "text" || part.type === "reasoning" || part.type === "tool-call") {
			sawContent = true;
		}
	}
	if (!sawContent) return;

	const tail = parts[parts.length - 1];
	if (tail && tail.type === "data-step-separator") return;

	parts.push({ type: "data-step-separator", data: { stepIndex: separatorCount } });
	state.currentTextPartIndex = -1;
	state.currentReasoningPartIndex = -1;
}
|
||||||
|
|
||||||
|
/**
 * Gate deciding which tool names produce a UI tool card: either an
 * explicit allowlist of names, or the sentinel ``"all"`` which admits
 * every tool. The legacy ``BASE_TOOLS_WITH_UI`` gate was dropped so that
 * ALL tool calls render via the generic ``ToolFallback``; the backend's
 * ``format_thinking_step`` summarisation and the defensive
 * ``result_length``-only default for unknown tools keep persisted
 * message JSON from ballooning.
 */
export type ToolUIGate = Set<string> | "all";

/** Return true when ``toolName`` is admitted by ``gate``. */
function _toolPasses(gate: ToolUIGate, toolName: string): boolean {
	if (gate === "all") {
		return true;
	}
	return gate.has(toolName);
}
|
||||||
|
|
||||||
/**
 * Push a new tool-call part and register its index under ``toolCallId``.
 *
 * The call is ignored unless ``force`` is set or ``toolName`` passes the
 * ``toolsWithUI`` gate. ``langchainToolCallId`` is stored on the part
 * only when it is non-empty. Adding a tool call closes any open text
 * and reasoning part.
 */
export function addToolCall(
	state: ContentPartsState,
	toolsWithUI: ToolUIGate,
	toolCallId: string,
	toolName: string,
	args: Record<string, unknown>,
	force = false,
	langchainToolCallId?: string
): void {
	const admitted = force || _toolPasses(toolsWithUI, toolName);
	if (!admitted) return;

	const toolCallPart: Extract<ContentPart, { type: "tool-call" }> = {
		type: "tool-call",
		toolCallId,
		toolName,
		args,
	};
	if (langchainToolCallId) {
		toolCallPart.langchainToolCallId = langchainToolCallId;
	}

	state.contentParts.push(toolCallPart);
	const insertedAt = state.contentParts.length - 1;
	state.toolCallIndices.set(toolCallId, insertedAt);
	state.currentTextPartIndex = -1;
	state.currentReasoningPartIndex = -1;
}
|
||||||
|
|
||||||
|
/**
 * Reverse lookup used by the SSE ``data-action-log`` handler: map a
 * LangChain ``tool_call.id`` (stored on a content part as
 * ``langchainToolCallId``) back to the synthetic ``toolCallId`` the chat
 * tool card uses (``call_<run-id>``). The first matching tool-call part
 * in insertion order wins. Returns ``null`` when no matching tool card
 * has been seen yet — the action is still recorded in the LC-id-keyed
 * atom so the card can pick it up when it eventually arrives.
 */
export function findToolCallIdByLcId(
	state: ContentPartsState,
	lcToolCallId: string
): string | null {
	const match = state.contentParts.find(
		(candidate) =>
			candidate.type === "tool-call" && candidate.langchainToolCallId === lcToolCallId
	);
	if (match && match.type === "tool-call") {
		return match.toolCallId;
	}
	return null;
}
|
||||||
|
|
||||||
/**
 * Patch an existing tool-call part in place.
 *
 * Does nothing when ``toolCallId`` is unknown or the registered index no
 * longer points at a tool-call part. ``args`` is replaced when truthy,
 * ``result`` when not ``undefined``.
 */
export function updateToolCall(
	state: ContentPartsState,
	toolCallId: string,
	update: { args?: Record<string, unknown>; result?: unknown; langchainToolCallId?: string }
): void {
	const index = state.toolCallIndices.get(toolCallId);
	if (index !== undefined && state.contentParts[index]?.type === "tool-call") {
		const tc = state.contentParts[index] as ContentPart & { type: "tool-call" };
		if (update.args) tc.args = update.args;
		if (update.result !== undefined) tc.result = update.result;
		// Backfill only: the first non-empty langchainToolCallId wins. A
		// later event never overwrites an id that is already set, and an
		// empty / undefined late-arriving value is ignored.
		// NOTE(review): this comment previously also said the authoritative
		// ``on_tool_end`` value should override an earlier best-effort
		// match. The guard below does not do that — confirm which
		// behaviour is intended.
		if (update.langchainToolCallId && !tc.langchainToolCallId) {
			tc.langchainToolCallId = update.langchainToolCallId;
		}
	}
}
|
||||||
|
|
||||||
|
|
@ -184,13 +307,15 @@ function _hasInterruptResult(part: ContentPart): boolean {
|
||||||
|
|
||||||
export function buildContentForUI(
|
export function buildContentForUI(
|
||||||
state: ContentPartsState,
|
state: ContentPartsState,
|
||||||
toolsWithUI: Set<string>
|
toolsWithUI: ToolUIGate
|
||||||
): ThreadMessageLike["content"] {
|
): ThreadMessageLike["content"] {
|
||||||
const filtered = state.contentParts.filter((part) => {
|
const filtered = state.contentParts.filter((part) => {
|
||||||
if (part.type === "text") return part.text.length > 0;
|
if (part.type === "text") return part.text.length > 0;
|
||||||
|
if (part.type === "reasoning") return part.text.length > 0;
|
||||||
if (part.type === "tool-call")
|
if (part.type === "tool-call")
|
||||||
return toolsWithUI.has(part.toolName) || _hasInterruptResult(part);
|
return _toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part);
|
||||||
if (part.type === "data-thinking-steps") return true;
|
if (part.type === "data-thinking-steps") return true;
|
||||||
|
if (part.type === "data-step-separator") return true;
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
return filtered.length > 0
|
return filtered.length > 0
|
||||||
|
|
@ -200,20 +325,28 @@ export function buildContentForUI(
|
||||||
|
|
||||||
export function buildContentForPersistence(
|
export function buildContentForPersistence(
|
||||||
state: ContentPartsState,
|
state: ContentPartsState,
|
||||||
toolsWithUI: Set<string>
|
toolsWithUI: ToolUIGate
|
||||||
): unknown[] {
|
): unknown[] {
|
||||||
const parts: unknown[] = [];
|
const parts: unknown[] = [];
|
||||||
|
|
||||||
for (const part of state.contentParts) {
|
for (const part of state.contentParts) {
|
||||||
if (part.type === "text" && part.text.length > 0) {
|
if (part.type === "text" && part.text.length > 0) {
|
||||||
parts.push(part);
|
parts.push(part);
|
||||||
|
} else if (part.type === "reasoning" && part.text.length > 0) {
|
||||||
|
// Persist reasoning blocks so a chat reload re-renders the
|
||||||
|
// collapsed thinking section instead of
|
||||||
|
// silently dropping it (mirrors the data-thinking-steps
|
||||||
|
// branch below).
|
||||||
|
parts.push(part);
|
||||||
} else if (
|
} else if (
|
||||||
part.type === "tool-call" &&
|
part.type === "tool-call" &&
|
||||||
(toolsWithUI.has(part.toolName) || _hasInterruptResult(part))
|
(_toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part))
|
||||||
) {
|
) {
|
||||||
parts.push(part);
|
parts.push(part);
|
||||||
} else if (part.type === "data-thinking-steps") {
|
} else if (part.type === "data-thinking-steps") {
|
||||||
parts.push(part);
|
parts.push(part);
|
||||||
|
} else if (part.type === "data-step-separator") {
|
||||||
|
parts.push(part);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -221,23 +354,122 @@ export function buildContentForPersistence(
|
||||||
}
|
}
|
||||||
|
|
||||||
export type SSEEvent =
|
export type SSEEvent =
|
||||||
| { type: "text-delta"; delta: string }
|
| { type: "start"; messageId?: string }
|
||||||
| { type: "tool-input-start"; toolCallId: string; toolName: string }
|
| { type: "finish" }
|
||||||
|
| { type: "start-step" }
|
||||||
|
| { type: "finish-step" }
|
||||||
|
| { type: "text-start"; id: string }
|
||||||
|
| { type: "text-delta"; id?: string; delta: string }
|
||||||
|
| { type: "text-end"; id: string }
|
||||||
|
| { type: "reasoning-start"; id: string }
|
||||||
|
| { type: "reasoning-delta"; id?: string; delta: string }
|
||||||
|
| { type: "reasoning-end"; id: string }
|
||||||
|
| {
|
||||||
|
type: "tool-input-start";
|
||||||
|
toolCallId: string;
|
||||||
|
toolName: string;
|
||||||
|
/** Authoritative LangChain ``tool_call.id``. Optional. */
|
||||||
|
langchainToolCallId?: string;
|
||||||
|
}
|
||||||
| {
|
| {
|
||||||
type: "tool-input-available";
|
type: "tool-input-available";
|
||||||
toolCallId: string;
|
toolCallId: string;
|
||||||
toolName: string;
|
toolName: string;
|
||||||
input: Record<string, unknown>;
|
input: Record<string, unknown>;
|
||||||
|
langchainToolCallId?: string;
|
||||||
}
|
}
|
||||||
| {
|
| {
|
||||||
type: "tool-output-available";
|
type: "tool-output-available";
|
||||||
toolCallId: string;
|
toolCallId: string;
|
||||||
output: Record<string, unknown>;
|
output: Record<string, unknown>;
|
||||||
|
/** Authoritative LangChain ``tool_call.id`` extracted from
|
||||||
|
* ``ToolMessage.tool_call_id`` at on_tool_end. Backfills cards
|
||||||
|
* that didn't get the id at tool-input-start time. */
|
||||||
|
langchainToolCallId?: string;
|
||||||
}
|
}
|
||||||
| { type: "data-thinking-step"; data: ThinkingStepData }
|
| { type: "data-thinking-step"; data: ThinkingStepData }
|
||||||
| { type: "data-thread-title-update"; data: { threadId: number; title: string } }
|
| { type: "data-thread-title-update"; data: { threadId: number; title: string } }
|
||||||
| { type: "data-interrupt-request"; data: Record<string, unknown> }
|
| { type: "data-interrupt-request"; data: Record<string, unknown> }
|
||||||
| { type: "data-documents-updated"; data: Record<string, unknown> }
|
| { type: "data-documents-updated"; data: Record<string, unknown> }
|
||||||
|
| {
|
||||||
|
/**
|
||||||
|
* A freshly committed AgentActionLog row. Frontend stores
|
||||||
|
* this in a Map keyed off ``lc_tool_call_id`` so the chat
|
||||||
|
* tool card can light up its Revert button.
|
||||||
|
*/
|
||||||
|
type: "data-action-log";
|
||||||
|
data: {
|
||||||
|
id: number;
|
||||||
|
lc_tool_call_id: string | null;
|
||||||
|
chat_turn_id: string | null;
|
||||||
|
tool_name: string;
|
||||||
|
reversible: boolean;
|
||||||
|
reverse_descriptor_present: boolean;
|
||||||
|
created_at: string | null;
|
||||||
|
error: boolean;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
/**
|
||||||
|
* Reversibility flipped (filesystem op SAVEPOINT committed;
|
||||||
|
* cf. ``kb_persistence._dispatch_reversibility_update``).
|
||||||
|
*/
|
||||||
|
type: "data-action-log-updated";
|
||||||
|
data: { id: number; reversible: boolean };
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
/**
|
||||||
|
* Emitted at the start of every stream so the frontend can
|
||||||
|
* stamp the per-turn correlation id onto the in-flight
|
||||||
|
* assistant message and replay it via
|
||||||
|
* ``appendMessage``. Pure-text turns never produce
|
||||||
|
* action-log events; this event guarantees the frontend
|
||||||
|
* always learns the turn id.
|
||||||
|
*/
|
||||||
|
type: "data-turn-info";
|
||||||
|
data: { chat_turn_id: string };
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
/**
|
||||||
|
* Best-effort revert pass that ran BEFORE this regeneration.
|
||||||
|
* Per-action results are forwarded to the UI so the user
|
||||||
|
* can see which downstream actions were rolled
|
||||||
|
* back vs which couldn't be undone.
|
||||||
|
*/
|
||||||
|
type: "data-revert-results";
|
||||||
|
data: {
|
||||||
|
status: "ok" | "partial";
|
||||||
|
chat_turn_ids: string[];
|
||||||
|
total: number;
|
||||||
|
reverted: number;
|
||||||
|
already_reverted: number;
|
||||||
|
not_reversible: number;
|
||||||
|
/**
|
||||||
|
* ``permission_denied`` and ``skipped`` are first-class
|
||||||
|
* counters so the response invariant
|
||||||
|
* ``total === sum(counters)`` always holds. Optional
|
||||||
|
* for backward compatibility with older backends; the
|
||||||
|
* frontend treats missing values as ``0``.
|
||||||
|
*/
|
||||||
|
permission_denied?: number;
|
||||||
|
failed: number;
|
||||||
|
skipped?: number;
|
||||||
|
results: Array<{
|
||||||
|
action_id: number;
|
||||||
|
tool_name: string;
|
||||||
|
status:
|
||||||
|
| "reverted"
|
||||||
|
| "already_reverted"
|
||||||
|
| "not_reversible"
|
||||||
|
| "permission_denied"
|
||||||
|
| "failed"
|
||||||
|
| "skipped";
|
||||||
|
message?: string | null;
|
||||||
|
new_action_id?: number | null;
|
||||||
|
error?: string | null;
|
||||||
|
}>;
|
||||||
|
};
|
||||||
|
}
|
||||||
| {
|
| {
|
||||||
type: "data-token-usage";
|
type: "data-token-usage";
|
||||||
data: {
|
data: {
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,11 @@ export interface MessageRecord {
|
||||||
author_display_name?: string | null;
|
author_display_name?: string | null;
|
||||||
author_avatar_url?: string | null;
|
author_avatar_url?: string | null;
|
||||||
token_usage?: TokenUsageSummary | null;
|
token_usage?: TokenUsageSummary | null;
|
||||||
|
// Per-turn correlation id from ``configurable.turn_id`` at streaming
|
||||||
|
// time (added in migration 136). Used by the per-turn revert
|
||||||
|
// endpoint and edit-from-arbitrary-position. Nullable on legacy
|
||||||
|
// rows that predate the column.
|
||||||
|
turn_id?: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ThreadListResponse {
|
export interface ThreadListResponse {
|
||||||
|
|
@ -123,10 +128,20 @@ export async function getThreadMessages(threadId: number): Promise<ThreadHistory
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Append a message to a thread.
|
* Append a message to a thread.
|
||||||
|
*
|
||||||
|
* ``turn_id`` is the per-turn correlation id streamed by the backend
|
||||||
|
* via ``data-turn-info``. Persisting it lets later edits locate the
|
||||||
|
* matching LangGraph checkpoint without HumanMessage scanning. Older
|
||||||
|
* callers can still omit it for back-compat.
|
||||||
*/
|
*/
|
||||||
export async function appendMessage(
|
export async function appendMessage(
|
||||||
threadId: number,
|
threadId: number,
|
||||||
message: { role: "user" | "assistant" | "system"; content: unknown; token_usage?: unknown }
|
message: {
|
||||||
|
role: "user" | "assistant" | "system";
|
||||||
|
content: unknown;
|
||||||
|
token_usage?: unknown;
|
||||||
|
turn_id?: string | null;
|
||||||
|
}
|
||||||
): Promise<MessageRecord> {
|
): Promise<MessageRecord> {
|
||||||
return baseApiService.post<MessageRecord>(`/api/v1/threads/${threadId}/messages`, undefined, {
|
return baseApiService.post<MessageRecord>(`/api/v1/threads/${threadId}/messages`, undefined, {
|
||||||
body: message,
|
body: message,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue