mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 06:12:40 +02:00
Merge remote-tracking branch 'upstream/dev' into feature/multi-agent
This commit is contained in:
commit
5d3b8b9ca9
83 changed files with 10514 additions and 638 deletions
8
.github/workflows/desktop-release.yml
vendored
8
.github/workflows/desktop-release.yml
vendored
|
|
@ -136,6 +136,14 @@ jobs:
|
|||
AZURE_CODESIGN_ENDPOINT: ${{ vars.AZURE_CODESIGN_ENDPOINT }}
|
||||
AZURE_CODESIGN_ACCOUNT: ${{ vars.AZURE_CODESIGN_ACCOUNT }}
|
||||
AZURE_CODESIGN_PROFILE: ${{ vars.AZURE_CODESIGN_PROFILE }}
|
||||
# macOS Developer ID signing + notarization. Only the macos-latest runner
|
||||
# consumes these; Windows/Linux runners ignore them. CSC_LINK accepts either
|
||||
# a file path or a base64-encoded .p12 blob — electron-builder auto-detects.
|
||||
CSC_LINK: ${{ secrets.MAC_CERT_P12_BASE64 }}
|
||||
CSC_KEY_PASSWORD: ${{ secrets.MAC_CERT_PASSWORD }}
|
||||
APPLE_ID: ${{ secrets.APPLE_ID }}
|
||||
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
|
||||
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
|
||||
# Service principal credentials for Azure.Identity EnvironmentCredential used by the
|
||||
# TrustedSigning PowerShell module. Only populated when signing is enabled.
|
||||
# electron-builder 26 does not yet support OIDC federated tokens for Azure signing,
|
||||
|
|
|
|||
|
|
@ -285,6 +285,14 @@ LANGSMITH_PROJECT=surfsense
|
|||
# SURFSENSE_ENABLE_ACTION_LOG=false
|
||||
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
|
||||
|
||||
# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk
|
||||
# content (typed reasoning blocks, tool-input deltas) and propagate the
|
||||
# real tool_call_id to the SSE layer. When OFF, the stream falls back to
|
||||
# the str-only text path and synthetic "call_<run_id>" tool-call ids.
|
||||
# Schema migrations 135/136 ship unconditionally because they are
|
||||
# forward-compatible.
|
||||
# SURFSENSE_ENABLE_STREAM_PARITY_V2=false
|
||||
|
||||
# Plugins
|
||||
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
||||
# Comma-separated allowlist of plugin entry-point names
|
||||
|
|
|
|||
139
surfsense_backend/alembic/versions/134_relax_revision_fks.py
Normal file
139
surfsense_backend/alembic/versions/134_relax_revision_fks.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
"""134_relax_revision_fks
|
||||
|
||||
Revision ID: 134
|
||||
Revises: 133
|
||||
Create Date: 2026-04-29
|
||||
|
||||
Relax the parent FKs on ``document_revisions`` and ``folder_revisions`` so
|
||||
revisions survive the deletes they describe.
|
||||
|
||||
Why: the snapshot/revert pipeline writes a ``DocumentRevision`` BEFORE
|
||||
hard-deleting a document via the ``rm`` tool (and likewise a
|
||||
``FolderRevision`` before ``rmdir``). If the FK is ``ON DELETE CASCADE``
|
||||
the snapshot row is wiped at the exact moment we need it most — revert
|
||||
then has nothing to read and the operation becomes irreversible.
|
||||
|
||||
Migration:
|
||||
|
||||
* ``document_revisions.document_id``: ``NOT NULL`` -> nullable; FK
|
||||
``ON DELETE CASCADE`` -> ``ON DELETE SET NULL``.
|
||||
* ``folder_revisions.folder_id``: same treatment.
|
||||
|
||||
The ``search_space_id`` FK on both tables is left unchanged (still
|
||||
``ON DELETE CASCADE``). When a search space is deleted, all documents,
|
||||
folders, AND their revisions go together — that's the correct teardown
|
||||
story.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "134"
|
||||
down_revision: str | None = "133"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def _fk_name(bind, table: str, column: str) -> str | None:
    """Look up the foreign-key constraint covering exactly ``table.column``.

    Args:
        bind: Connection/engine handed to SQLAlchemy's ``inspect``.
        table: Name of the table whose foreign keys are reflected.
        column: The single constrained column to match.

    Returns:
        The ``name`` entry of the first reflected foreign key whose
        constrained-column list equals ``[column]``, or ``None`` when no such
        foreign key is reflected.
    """
    wanted = [column]
    reflected = inspect(bind).get_foreign_keys(table)
    # Multi-column FKs that merely include ``column`` are deliberately not
    # matched: the comparison is against the whole constrained-column list.
    candidates = (
        entry
        for entry in reflected
        if (entry.get("constrained_columns") or []) == wanted
    )
    first_match = next(candidates, None)
    return None if first_match is None else first_match.get("name")
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Make the revision parent columns nullable with ``ON DELETE SET NULL``.

    For ``document_revisions.document_id`` and then
    ``folder_revisions.folder_id`` this:

    1. drops the existing single-column foreign key, if one is reflected;
    2. alters the column to be nullable;
    3. recreates the foreign key to the parent table's ``id`` with
       ``ON DELETE SET NULL``.
    """
    bind = op.get_bind()

    # (child table, FK column, parent table, name of the recreated constraint)
    relaxed_fks = (
        (
            "document_revisions",
            "document_id",
            "documents",
            "document_revisions_document_id_fkey",
        ),
        (
            "folder_revisions",
            "folder_id",
            "folders",
            "folder_revisions_folder_id_fkey",
        ),
    )

    for child_table, fk_column, parent_table, constraint_name in relaxed_fks:
        # The live constraint name is reflected rather than assumed.
        current_name = _fk_name(bind, child_table, fk_column)
        if current_name:
            op.drop_constraint(current_name, child_table, type_="foreignkey")
        op.alter_column(
            child_table,
            fk_column,
            existing_type=sa.Integer(),
            nullable=True,
        )
        op.create_foreign_key(
            constraint_name,
            child_table,
            parent_table,
            [fk_column],
            ["id"],
            ondelete="SET NULL",
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore ``NOT NULL`` + ``ON DELETE CASCADE`` on the revision parent FKs.

    This is destructive: revision rows whose parent id is NULL are deleted
    first, because they cannot satisfy the ``NOT NULL`` constraint that is
    being reinstated. Afterwards, for ``document_revisions.document_id`` and
    then ``folder_revisions.folder_id``, the reflected foreign key (if any) is
    dropped, the column is made non-nullable, and the foreign key is recreated
    with ``ON DELETE CASCADE``.
    """
    bind = op.get_bind()

    # Drain orphaned revisions (parent doc/folder already deleted) up front.
    op.execute("DELETE FROM document_revisions WHERE document_id IS NULL")
    op.execute("DELETE FROM folder_revisions WHERE folder_id IS NULL")

    # (child table, FK column, parent table, name of the recreated constraint)
    restored_fks = (
        (
            "document_revisions",
            "document_id",
            "documents",
            "document_revisions_document_id_fkey",
        ),
        (
            "folder_revisions",
            "folder_id",
            "folders",
            "folder_revisions_folder_id_fkey",
        ),
    )

    for child_table, fk_column, parent_table, constraint_name in restored_fks:
        current_name = _fk_name(bind, child_table, fk_column)
        if current_name:
            op.drop_constraint(current_name, child_table, type_="foreignkey")
        op.alter_column(
            child_table,
            fk_column,
            existing_type=sa.Integer(),
            nullable=False,
        )
        op.create_foreign_key(
            constraint_name,
            child_table,
            parent_table,
            [fk_column],
            ["id"],
            ondelete="CASCADE",
        )
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
"""135_action_log_correlation_ids
|
||||
|
||||
Revision ID: 135
|
||||
Revises: 134
|
||||
Create Date: 2026-04-29
|
||||
|
||||
Action-log correlation-id cleanup.
|
||||
|
||||
Background
|
||||
----------
|
||||
``agent_action_log.turn_id`` is misnamed. ``ActionLogMiddleware`` writes
|
||||
the LangChain ``tool_call.id`` into that column today (see
|
||||
``action_log.py:_resolve_turn_id``), and ``kb_persistence._find_action_ids_batch``
|
||||
joins on it as such. The real chat-turn id (``f"{chat_id}:{ms}"`` from
|
||||
``stream_new_chat.py``) lives in ``config.configurable.turn_id`` and was
|
||||
never persisted.
|
||||
|
||||
This migration introduces two new, correctly-named columns:
|
||||
|
||||
* ``tool_call_id`` (LangChain tool-call id, what ``turn_id`` actually held)
|
||||
* ``chat_turn_id`` (the per-turn correlation id from
|
||||
``configurable.turn_id`` — used by the per-turn ``revert-turn`` route).
|
||||
|
||||
Backfill copies the current ``turn_id`` values into ``tool_call_id`` so
|
||||
existing joins keep working. The old ``turn_id`` column is left in place
|
||||
for one release as a deprecated alias to give safe rollback. ``ActionLogMiddleware``
|
||||
keeps writing it (= ``tool_call_id``) for the same reason.
|
||||
|
||||
Indexes
|
||||
-------
|
||||
|
||||
* ``ix_agent_action_log_tool_call_id`` — required by
|
||||
``_find_action_ids_batch`` (was on ``turn_id``).
|
||||
* ``ix_agent_action_log_chat_turn_id`` — required by the
|
||||
``revert-turn/{chat_turn_id}`` query.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "135"
|
||||
down_revision: str | None = "134"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add ``tool_call_id`` / ``chat_turn_id`` to ``agent_action_log``.

    Steps, in order:

    1. Add both columns as nullable ``String(64)``.
    2. Backfill ``tool_call_id`` from the legacy ``turn_id`` column.
    3. Create the lookup indexes on the two new columns.

    The backfill runs before the indexes exist on purpose: a bulk ``UPDATE``
    against an already-indexed column has to maintain the index row by row,
    whereas building the index once over the final data is cheaper. The
    resulting schema and data are the same either way.
    """
    op.add_column(
        "agent_action_log",
        sa.Column("tool_call_id", sa.String(length=64), nullable=True),
    )
    op.add_column(
        "agent_action_log",
        sa.Column("chat_turn_id", sa.String(length=64), nullable=True),
    )

    # Copy the legacy values across while the new column is still unindexed.
    op.execute(
        "UPDATE agent_action_log SET tool_call_id = turn_id WHERE tool_call_id IS NULL"
    )

    op.create_index(
        "ix_agent_action_log_tool_call_id",
        "agent_action_log",
        ["tool_call_id"],
    )
    op.create_index(
        "ix_agent_action_log_chat_turn_id",
        "agent_action_log",
        ["chat_turn_id"],
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the two correlation-id columns and their indexes.

    Both indexes are dropped before the columns, ``chat_turn_id`` first in
    each pass. The legacy ``turn_id`` column is not touched here.
    """
    table = "agent_action_log"

    for index_name in (
        "ix_agent_action_log_chat_turn_id",
        "ix_agent_action_log_tool_call_id",
    ):
        op.drop_index(index_name, table_name=table)

    for column_name in ("chat_turn_id", "tool_call_id"):
        op.drop_column(table, column_name)
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
"""136_new_chat_message_turn_id
|
||||
|
||||
Revision ID: 136
|
||||
Revises: 135
|
||||
Create Date: 2026-04-29
|
||||
|
||||
Persist the per-turn correlation id on each chat message.
|
||||
|
||||
Background
|
||||
----------
|
||||
LangGraph's checkpointer stores user-provided ``configurable.turn_id``
|
||||
in checkpoint metadata (see
|
||||
``langgraph/checkpoint/base/__init__.py:get_checkpoint_metadata``). To
|
||||
support edit-from-arbitrary-position, the regenerate route needs to map
|
||||
a ``message_id`` -> ``turn_id`` -> checkpoint at request time. Without
|
||||
this column the mapping doesn't exist anywhere, so regenerate would
|
||||
have to hardcode the "last 2 messages" rewind heuristic.
|
||||
|
||||
This migration adds a nullable ``turn_id`` column to ``new_chat_messages``
|
||||
plus an index. Legacy rows have NULL — the regenerate route degrades
|
||||
gracefully to the reload-last-two heuristic for those.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "136"
|
||||
down_revision: str | None = "135"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add a nullable, indexed ``turn_id`` column to ``new_chat_messages``.

    No backfill is performed, so rows that exist before this migration keep
    ``turn_id = NULL``.
    """
    table = "new_chat_messages"
    turn_id_column = sa.Column("turn_id", sa.String(length=64), nullable=True)

    op.add_column(table, turn_id_column)
    op.create_index("ix_new_chat_messages_turn_id", table, ["turn_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``turn_id`` index, then the column, from ``new_chat_messages``."""
    table = "new_chat_messages"
    op.drop_index("ix_new_chat_messages_turn_id", table_name=table)
    op.drop_column(table, "turn_id")
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
"""137_unique_reverse_of_in_action_log
|
||||
|
||||
Revision ID: 137
|
||||
Revises: 136
|
||||
Create Date: 2026-04-29
|
||||
|
||||
Protect ``agent_action_log.reverse_of`` against double inserts. Two
|
||||
concurrent revert calls (single-action route + the per-turn batch
|
||||
route, or two batch routes racing) both pass the
|
||||
``_was_already_reverted`` SELECT and both insert their own
|
||||
``_revert:*`` rows. The application-level idempotency check is racy
|
||||
because there's no DB constraint backing it.
|
||||
|
||||
This migration adds a partial unique index on ``reverse_of`` (PostgreSQL
|
||||
``WHERE reverse_of IS NOT NULL``) so the second concurrent insert raises
|
||||
``IntegrityError`` and the route can translate it to ``"already_reverted"``
|
||||
deterministically.
|
||||
|
||||
A plain ``UniqueConstraint`` would also work here: most existing rows
have ``reverse_of = NULL`` (only revert rows fill it), and Postgres
treats NULLs as distinct in unique indexes by default, so those rows
would not collide. A partial index is used anyway because it is the
cleanest expression of intent (uniqueness applies only to rows that
actually reverse something) and it does not rely on NULL-handling
semantics at all.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "137"
|
||||
down_revision: str | None = "136"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
_INDEX_NAME = "ux_agent_action_log_reverse_of"
|
||||
|
||||
|
||||
def upgrade() -> None:
    """De-duplicate ``reverse_of`` and add the partial unique index.

    First, for every non-NULL ``reverse_of`` value that appears on more than
    one row, the row with the smallest ``id`` keeps its value and the other
    rows get ``reverse_of = NULL``. No rows are deleted. Then a unique index
    restricted to ``reverse_of IS NOT NULL`` is created.
    """
    # ROW_NUMBER() ranks rows per reverse_of by ascending id; rn > 1 marks
    # every duplicate after the oldest one.
    null_out_duplicates = """
        WITH dups AS (
            SELECT id,
                   reverse_of,
                   ROW_NUMBER() OVER (
                       PARTITION BY reverse_of ORDER BY id ASC
                   ) AS rn
            FROM agent_action_log
            WHERE reverse_of IS NOT NULL
        )
        UPDATE agent_action_log
        SET reverse_of = NULL
        WHERE id IN (SELECT id FROM dups WHERE rn > 1)
        """
    op.execute(null_out_duplicates)

    # The de-dup must run first, otherwise index creation would fail on any
    # pre-existing duplicate values.
    op.create_index(
        _INDEX_NAME,
        "agent_action_log",
        ["reverse_of"],
        unique=True,
        postgresql_where="reverse_of IS NOT NULL",
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the partial unique index on ``agent_action_log.reverse_of``.

    Rows whose ``reverse_of`` was set to NULL by the upgrade's de-duplication
    step are not restored.
    """
    table = "agent_action_log"
    op.drop_index(_INDEX_NAME, table_name=table)
|
||||
|
|
@ -728,7 +728,8 @@ def _build_compiled_agent_blocking(
|
|||
repair_mw = None
|
||||
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
|
||||
registered_names: set[str] = {t.name for t in tools}
|
||||
# Tools owned by the standard deepagents middleware stack.
|
||||
# Tools owned by the standard deepagents middleware stack and the
|
||||
# SurfSense filesystem extension.
|
||||
registered_names |= {
|
||||
"write_todos",
|
||||
"ls",
|
||||
|
|
@ -739,6 +740,14 @@ def _build_compiled_agent_blocking(
|
|||
"grep",
|
||||
"execute",
|
||||
"task",
|
||||
"mkdir",
|
||||
"cd",
|
||||
"pwd",
|
||||
"move_file",
|
||||
"rm",
|
||||
"rmdir",
|
||||
"list_tree",
|
||||
"execute_code",
|
||||
}
|
||||
repair_mw = ToolCallNameRepairMiddleware(
|
||||
registered_tool_names=registered_names,
|
||||
|
|
@ -767,25 +776,51 @@ def _build_compiled_agent_blocking(
|
|||
# on every safe read-only call (``ls``, ``read_file``, ``grep``,
|
||||
# ``glob``, ``web_search`` …) and, on resume, replay the previous
|
||||
# reject decision into innocent calls.
|
||||
# 2. ``connector_synthesized`` — deny rules for tools whose required
|
||||
# connector is not connected to this space. Overrides #1.
|
||||
# 3. (future) user-defined rules from ``agent_permission_rules`` table
|
||||
# via the Agent Permissions UI. Loaded last so they override both.
|
||||
# 2. ``desktop_safety`` — ``ask`` for destructive filesystem ops when
|
||||
# the agent is operating against the user's real disk. Cloud mode
|
||||
# has full revision-based revert via ``revert_service``, but
|
||||
# desktop mode hits disk immediately with no undo, so an
|
||||
# accidental ``rm`` / ``rmdir`` / ``move_file`` / ``edit_file`` /
|
||||
# ``write_file`` is unrecoverable. This layer is forced on in
|
||||
# desktop mode regardless of ``enable_permission`` because the
|
||||
# safety net is non-negotiable.
|
||||
# 3. ``connector_synthesized`` — deny rules for tools whose required
|
||||
# connector is not connected to this space. Overrides #1/#2.
|
||||
# 4. (future) user-defined rules from ``agent_permission_rules`` table
|
||||
# via the Agent Permissions UI. Loaded last so they override all.
|
||||
permission_mw: PermissionMiddleware | None = None
|
||||
if flags.enable_permission and not flags.disable_new_agent_stack:
|
||||
synthesized = _synthesize_connector_deny_rules(
|
||||
available_connectors=available_connectors,
|
||||
enabled_tool_names={t.name for t in tools},
|
||||
)
|
||||
permission_mw = PermissionMiddleware(
|
||||
rulesets=[
|
||||
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||
permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
|
||||
# Build the middleware whenever it has work to do: either the user
|
||||
# opted into the rule engine, OR we're in desktop mode and need the
|
||||
# safety rules unconditionally.
|
||||
if permission_enabled or is_desktop_fs:
|
||||
rulesets: list[Ruleset] = [
|
||||
Ruleset(
|
||||
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||
origin="surfsense_defaults",
|
||||
),
|
||||
]
|
||||
if is_desktop_fs:
|
||||
rulesets.append(
|
||||
Ruleset(
|
||||
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||
origin="surfsense_defaults",
|
||||
),
|
||||
Ruleset(rules=synthesized, origin="connector_synthesized"),
|
||||
],
|
||||
)
|
||||
rules=[
|
||||
Rule(permission="rm", pattern="*", action="ask"),
|
||||
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||
Rule(permission="move_file", pattern="*", action="ask"),
|
||||
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||
Rule(permission="write_file", pattern="*", action="ask"),
|
||||
],
|
||||
origin="desktop_safety",
|
||||
)
|
||||
)
|
||||
if permission_enabled:
|
||||
synthesized = _synthesize_connector_deny_rules(
|
||||
available_connectors=available_connectors,
|
||||
enabled_tool_names={t.name for t in tools},
|
||||
)
|
||||
rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized"))
|
||||
permission_mw = PermissionMiddleware(rulesets=rulesets)
|
||||
|
||||
# ActionLogMiddleware. Off by default until the ``agent_action_log``
|
||||
# table is migrated. When enabled, persists one row per tool call
|
||||
|
|
@ -942,6 +977,7 @@ def _build_compiled_agent_blocking(
|
|||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
if filesystem_mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ Local development (recommended for trying everything except doom-loop / selector
|
|||
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
|
||||
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
|
||||
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
|
||||
SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events
|
||||
|
||||
Master kill-switch (overrides everything else):
|
||||
|
||||
|
|
@ -86,6 +87,15 @@ class AgentFeatureFlags:
|
|||
False # Backend ships before UI; route returns 503 until this flips
|
||||
)
|
||||
|
||||
# Streaming parity v2 — opt in to LangChain's structured
|
||||
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
|
||||
# deltas) and propagate the real ``tool_call_id`` to the SSE layer.
|
||||
# When OFF the ``stream_new_chat`` task falls back to the str-only
|
||||
# text path and the synthetic ``call_<run_id>`` tool-call id (no
|
||||
# ``langchainToolCallId`` propagation). Schema migrations 135/136
|
||||
# ship unconditionally because they're forward-compatible.
|
||||
enable_stream_parity_v2: bool = False
|
||||
|
||||
# Plugins
|
||||
enable_plugin_loader: bool = False
|
||||
|
||||
|
|
@ -139,6 +149,10 @@ class AgentFeatureFlags:
|
|||
# Snapshot / revert
|
||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
|
||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
|
||||
# Streaming parity v2
|
||||
enable_stream_parity_v2=_env_bool(
|
||||
"SURFSENSE_ENABLE_STREAM_PARITY_V2", False
|
||||
),
|
||||
# Plugins
|
||||
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
||||
# Observability
|
||||
|
|
|
|||
|
|
@ -5,9 +5,14 @@ extra fields needed to implement Postgres-backed virtual filesystem semantics:
|
|||
|
||||
* ``cwd`` — current working directory (per-thread checkpointed).
|
||||
* ``staged_dirs`` — pending mkdir requests (cloud only).
|
||||
* ``staged_dir_tool_calls`` — sidecar map ``path -> tool_call_id`` for staged dirs.
|
||||
* ``pending_moves`` — pending move_file requests (cloud only).
|
||||
* ``pending_deletes`` — pending ``rm`` requests (cloud only).
|
||||
* ``pending_dir_deletes`` — pending ``rmdir`` requests (cloud only).
|
||||
* ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads.
|
||||
* ``dirty_paths`` — paths whose state file content differs from DB.
|
||||
* ``dirty_path_tool_calls`` — sidecar map ``path -> latest tool_call_id`` for
|
||||
dirty paths; used to bind the per-path snapshot to an action_id.
|
||||
* ``kb_priority`` — top-K priority hints rendered into a system message.
|
||||
* ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting.
|
||||
* ``kb_anon_doc`` — Redis-loaded anonymous document (if any).
|
||||
|
|
@ -32,12 +37,31 @@ from app.agents.new_chat.state_reducers import (
|
|||
)
|
||||
|
||||
|
||||
class PendingMove(TypedDict):
|
||||
"""A staged move_file operation pending end-of-turn commit."""
|
||||
class PendingMove(TypedDict, total=False):
|
||||
"""A staged move_file operation pending end-of-turn commit.
|
||||
|
||||
``tool_call_id`` is optional for backward compatibility with checkpoints
|
||||
written before the snapshot/revert pipeline was wired up; new entries
|
||||
always include it so the persistence body can resolve an action_id.
|
||||
"""
|
||||
|
||||
source: str
|
||||
dest: str
|
||||
overwrite: bool
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class PendingDelete(TypedDict, total=False):
    """One staged ``rm`` / ``rmdir`` request awaiting the end-of-turn commit.

    New entries are expected to carry ``tool_call_id``: it is the key
    :class:`KnowledgeBasePersistenceMiddleware` uses to locate the matching
    :class:`AgentActionLog` row and bind the snapshot to it. The type is
    declared ``total=False`` solely so that older checkpoint payloads, which
    may lack that key, still type-check.
    """

    # Path of the file or directory staged for deletion.
    path: str
    # LangChain tool-call id of the originating ``rm`` / ``rmdir`` call.
    tool_call_id: str
|
||||
|
||||
|
||||
class KbPriorityEntry(TypedDict, total=False):
|
||||
|
|
@ -76,9 +100,38 @@ class SurfSenseFilesystemState(FilesystemState):
|
|||
staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
||||
"""mkdir paths staged for end-of-turn folder creation (cloud only)."""
|
||||
|
||||
staged_dir_tool_calls: NotRequired[
|
||||
Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
|
||||
]
|
||||
"""``path -> tool_call_id`` sidecar for ``staged_dirs``.
|
||||
|
||||
Used by :class:`KnowledgeBasePersistenceMiddleware` to bind the
|
||||
:class:`FolderRevision` snapshot to the originating ``mkdir`` action.
|
||||
Kept separate from ``staged_dirs`` (which stays a unique-string list)
|
||||
to avoid breaking ``_add_unique_reducer`` semantics.
|
||||
"""
|
||||
|
||||
pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
|
||||
"""move_file ops staged for end-of-turn commit (cloud only)."""
|
||||
|
||||
pending_deletes: NotRequired[Annotated[list[PendingDelete], _list_append_reducer]]
|
||||
"""``rm`` ops staged for end-of-turn ``DELETE FROM documents`` (cloud only).
|
||||
|
||||
Each entry is a dict ``{"path": ..., "tool_call_id": ...}``. Per-path
|
||||
uniqueness is enforced inside the commit body, not the reducer (we keep
|
||||
``tool_call_id`` per occurrence so snapshot binding works).
|
||||
"""
|
||||
|
||||
pending_dir_deletes: NotRequired[
|
||||
Annotated[list[PendingDelete], _list_append_reducer]
|
||||
]
|
||||
"""``rmdir`` ops staged for end-of-turn ``DELETE FROM folders`` (cloud only).
|
||||
|
||||
Same shape as :data:`pending_deletes`. Commit body re-verifies the
|
||||
folder is empty (in-DB AND with this turn's pending changes accounted
|
||||
for) before issuing the DELETE.
|
||||
"""
|
||||
|
||||
doc_id_by_path: NotRequired[
|
||||
Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
|
||||
]
|
||||
|
|
@ -92,6 +145,17 @@ class SurfSenseFilesystemState(FilesystemState):
|
|||
dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
||||
"""Paths whose ``state["files"]`` content has been modified this turn."""
|
||||
|
||||
dirty_path_tool_calls: NotRequired[
|
||||
Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
|
||||
]
|
||||
"""``path -> latest tool_call_id`` sidecar for ``dirty_paths``.
|
||||
|
||||
The persistence body coalesces multiple writes/edits to the same path
|
||||
into one snapshot per turn. This map captures the most-recent
|
||||
``tool_call_id`` so the resulting :class:`DocumentRevision` is bound
|
||||
to the latest action_id (the one the user is most likely to revert).
|
||||
"""
|
||||
|
||||
kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
|
||||
"""Top-K priority hints rendered as a system message before the user turn."""
|
||||
|
||||
|
|
@ -108,6 +172,7 @@ class SurfSenseFilesystemState(FilesystemState):
|
|||
__all__ = [
|
||||
"KbAnonDoc",
|
||||
"KbPriorityEntry",
|
||||
"PendingDelete",
|
||||
"PendingMove",
|
||||
"SurfSenseFilesystemState",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from collections.abc import Awaitable, Callable
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.callbacks import adispatch_custom_event
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from app.agents.new_chat.feature_flags import get_flags
|
||||
|
|
@ -144,11 +145,19 @@ class ActionLogMiddleware(AgentMiddleware):
|
|||
result=result,
|
||||
)
|
||||
|
||||
tool_call_id = _resolve_tool_call_id(request)
|
||||
chat_turn_id = _resolve_chat_turn_id(request)
|
||||
|
||||
row = AgentActionLog(
|
||||
thread_id=self._thread_id,
|
||||
user_id=self._user_id,
|
||||
search_space_id=self._search_space_id,
|
||||
turn_id=_resolve_turn_id(request),
|
||||
# ``turn_id`` is the deprecated alias of ``tool_call_id``
|
||||
# kept for one release for safe rollback. New consumers
|
||||
# should read ``tool_call_id`` directly.
|
||||
turn_id=tool_call_id,
|
||||
tool_call_id=tool_call_id,
|
||||
chat_turn_id=chat_turn_id,
|
||||
message_id=_resolve_message_id(request),
|
||||
tool_name=tool_name,
|
||||
args=args_payload,
|
||||
|
|
@ -160,11 +169,41 @@ class ActionLogMiddleware(AgentMiddleware):
|
|||
async with shielded_async_session() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
row_id = int(row.id) if row.id is not None else None
|
||||
row_created_at = row.created_at
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"ActionLogMiddleware failed to persist action log row",
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
# Surface a side-channel SSE event so the chat tool card can
|
||||
# render a Revert button immediately after the row is durable.
|
||||
# ``stream_new_chat`` translates this into a
|
||||
# ``data-action-log`` SSE event. We DO NOT include the
|
||||
# ``reverse_descriptor`` payload here; only a presence flag.
|
||||
try:
|
||||
await adispatch_custom_event(
|
||||
"action_log",
|
||||
{
|
||||
"id": row_id,
|
||||
"lc_tool_call_id": tool_call_id,
|
||||
"chat_turn_id": chat_turn_id,
|
||||
"tool_name": tool_name,
|
||||
"reversible": bool(reversible),
|
||||
"reverse_descriptor_present": reverse_descriptor is not None,
|
||||
"created_at": row_created_at.isoformat()
|
||||
if row_created_at
|
||||
else None,
|
||||
"error": error_payload is not None,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ActionLogMiddleware failed to dispatch action_log event",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _render_reverse(
|
||||
self,
|
||||
|
|
@ -254,7 +293,8 @@ def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
|
|||
}
|
||||
|
||||
|
||||
def _resolve_turn_id(request: Any) -> str | None:
|
||||
def _resolve_tool_call_id(request: Any) -> str | None:
|
||||
"""Return the LangChain ``tool_call.id`` for this request, if any."""
|
||||
try:
|
||||
call = getattr(request, "tool_call", None) or {}
|
||||
if isinstance(call, dict):
|
||||
|
|
@ -266,9 +306,40 @@ def _resolve_turn_id(request: Any) -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
# Deprecated alias kept for one release. Old callers and tests treated
|
||||
# ``turn_id`` as if it carried the LangChain tool_call id; the new column
|
||||
# lives under ``tool_call_id``. Both resolve to the same value today.
|
||||
_resolve_turn_id = _resolve_tool_call_id
|
||||
|
||||
|
||||
def _resolve_chat_turn_id(request: Any) -> str | None:
|
||||
"""Return ``configurable.turn_id`` for this request, if accessible.
|
||||
|
||||
``ToolRuntime.config`` is exposed by LangGraph (see
|
||||
``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id
|
||||
lives at ``runtime.config["configurable"]["turn_id"]``.
|
||||
"""
|
||||
try:
|
||||
runtime = getattr(request, "runtime", None)
|
||||
if runtime is None:
|
||||
return None
|
||||
config = getattr(runtime, "config", None)
|
||||
if not isinstance(config, dict):
|
||||
return None
|
||||
configurable = config.get("configurable")
|
||||
if not isinstance(configurable, dict):
|
||||
return None
|
||||
value = configurable.get("turn_id")
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
except Exception: # pragma: no cover - defensive
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_message_id(request: Any) -> str | None:
|
||||
"""Tool-call IDs serve as best-available message correlator at this layer."""
|
||||
return _resolve_turn_id(request)
|
||||
return _resolve_tool_call_id(request)
|
||||
|
||||
|
||||
def _resolve_result_id(result: Any) -> str | None:
|
||||
|
|
|
|||
|
|
@ -102,6 +102,8 @@ current working directory (`cwd`, default `/documents`).
|
|||
- cd(path): change the current working directory.
|
||||
- pwd(): print the current working directory.
|
||||
- move_file(source, dest): move/rename a file under `/documents/`.
|
||||
- rm(path): delete a single file under `/documents/` (no `-r`).
|
||||
- rmdir(path): delete an empty directory under `/documents/`.
|
||||
- list_tree(path, max_depth, page_size): recursively list files/folders.
|
||||
|
||||
## Persistence Rules
|
||||
|
|
@ -112,8 +114,9 @@ current working directory (`cwd`, default `/documents`).
|
|||
`/documents/temp_scratch.md`) are **discarded** at end of turn — use this
|
||||
prefix for any scratch/working content you do NOT want saved.
|
||||
- All other paths (outside `/documents/` and not `temp_*`) are rejected.
|
||||
- mkdir/move_file are staged this turn and committed at end of turn alongside
|
||||
any new/edited documents.
|
||||
- mkdir/move_file/rm/rmdir are staged this turn and committed at end of
|
||||
turn alongside any new/edited documents. Snapshot/revert is enabled
|
||||
for every destructive operation when action logging is on.
|
||||
|
||||
## Reading Documents Efficiently
|
||||
|
||||
|
|
@ -176,6 +179,8 @@ directory (`cwd`).
|
|||
- cd(path): change the current working directory.
|
||||
- pwd(): print the current working directory.
|
||||
- move_file(source, dest): move/rename a file.
|
||||
- rm(path): delete a single file from disk (no `-r`). NOT reversible.
|
||||
- rmdir(path): delete an empty directory from disk. NOT reversible.
|
||||
- list_tree(path, max_depth, page_size): recursively list files/folders.
|
||||
|
||||
## Workflow Tips
|
||||
|
|
@ -184,6 +189,8 @@ directory (`cwd`).
|
|||
- For large trees, prefer `list_tree` then `grep` then `read_file` over
|
||||
brute-force directory traversal.
|
||||
- Cross-mount moves are not supported.
|
||||
- Desktop deletes hit disk immediately and cannot be undone via the
|
||||
agent's revert flow — confirm before calling `rm`/`rmdir`.
|
||||
"""
|
||||
)
|
||||
|
||||
|
|
@ -355,6 +362,42 @@ Notes:
|
|||
- Parent folders are created as needed.
|
||||
"""
|
||||
|
||||
_CLOUD_RM_TOOL_DESCRIPTION = """Deletes a single file under `/documents/`.
|
||||
|
||||
Mirrors POSIX `rm path` (no `-r`, no glob expansion). Stages the deletion
|
||||
for end-of-turn commit; the row is removed only after the agent's turn
|
||||
finishes successfully.
|
||||
|
||||
Args:
|
||||
- path: absolute or relative file path. Cannot point at a directory — use
|
||||
`rmdir` for empty folders. Cannot target the root or `/documents`.
|
||||
|
||||
Notes:
|
||||
- The action is reversible via the per-action revert flow when action
|
||||
logging is enabled.
|
||||
- The anonymous uploaded document is read-only and cannot be deleted.
|
||||
"""
|
||||
|
||||
_CLOUD_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory under `/documents/`.
|
||||
|
||||
Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
|
||||
deletion (`rm -r`) is intentionally NOT supported — clear contents with
|
||||
`rm` first.
|
||||
|
||||
Args:
|
||||
- path: absolute or relative directory path. Cannot target the root,
|
||||
`/documents`, the current cwd, or any ancestor of cwd (use `cd` to
|
||||
move out first).
|
||||
|
||||
Notes:
|
||||
- Emptiness is evaluated against the post-staged view, so a same-turn
|
||||
`rm /a/x.md` followed by `rmdir /a` is fine.
|
||||
- If the directory was added in this same turn via `mkdir` and never
|
||||
committed, the staged mkdir is dropped instead of issuing a delete.
|
||||
- The action is reversible via the per-action revert flow when action
|
||||
logging is enabled.
|
||||
"""
|
||||
|
||||
# --- desktop-only ----------------------------------------------------------
|
||||
|
||||
_DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
|
||||
|
|
@ -421,6 +464,28 @@ Notes:
|
|||
- Parent folders are created as needed.
|
||||
"""
|
||||
|
||||
_DESKTOP_RM_TOOL_DESCRIPTION = """Deletes a single file from disk.
|
||||
|
||||
Mirrors POSIX `rm path` (no `-r`, no glob expansion). The deletion hits
|
||||
disk immediately. Desktop deletes are NOT reversible via the agent's
|
||||
revert flow.
|
||||
|
||||
Args:
|
||||
- path: absolute mount-prefixed file path. Cannot point at a directory —
|
||||
use `rmdir` for empty folders.
|
||||
"""
|
||||
|
||||
_DESKTOP_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory from disk.
|
||||
|
||||
Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
|
||||
deletion is NOT supported. The deletion hits disk immediately and is
|
||||
NOT reversible via the agent's revert flow.
|
||||
|
||||
Args:
|
||||
- path: absolute mount-prefixed directory path. Cannot target the mount
|
||||
root or any directory containing files/subfolders.
|
||||
"""
|
||||
|
||||
|
||||
def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
|
||||
"""Pick the active-mode description for every filesystem tool."""
|
||||
|
|
@ -437,6 +502,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
|
|||
"mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION,
|
||||
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
|
||||
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
|
||||
"rm": _CLOUD_RM_TOOL_DESCRIPTION,
|
||||
"rmdir": _CLOUD_RMDIR_TOOL_DESCRIPTION,
|
||||
}
|
||||
return {
|
||||
"ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION,
|
||||
|
|
@ -450,6 +517,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
|
|||
"mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION,
|
||||
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
|
||||
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
|
||||
"rm": _DESKTOP_RM_TOOL_DESCRIPTION,
|
||||
"rmdir": _DESKTOP_RMDIR_TOOL_DESCRIPTION,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -476,6 +545,21 @@ def _basename(path: str) -> str:
|
|||
return path.rsplit("/", 1)[-1]
|
||||
|
||||
|
||||
def _is_ancestor_of(candidate: str, target: str) -> bool:
|
||||
"""True iff ``candidate`` is a strict ancestor directory of ``target``.
|
||||
|
||||
``target`` itself is NOT considered an ancestor (use equality for that).
|
||||
Both paths are assumed to be canonicalised, absolute, and free of
|
||||
trailing slashes (except the root ``/``).
|
||||
"""
|
||||
if not candidate.startswith("/") or not target.startswith("/"):
|
||||
return False
|
||||
if candidate == target:
|
||||
return False
|
||||
prefix = candidate.rstrip("/") + "/"
|
||||
return target.startswith(prefix)
|
||||
|
||||
|
||||
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||
"""SurfSense-specific filesystem middleware (cloud + desktop)."""
|
||||
|
||||
|
|
@ -519,6 +603,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
self.tools.append(self._create_cd_tool())
|
||||
self.tools.append(self._create_pwd_tool())
|
||||
self.tools.append(self._create_move_file_tool())
|
||||
self.tools.append(self._create_rm_tool())
|
||||
self.tools.append(self._create_rmdir_tool())
|
||||
self.tools.append(self._create_list_tree_tool())
|
||||
if self._sandbox_available:
|
||||
self.tools.append(self._create_execute_code_tool())
|
||||
|
|
@ -941,6 +1027,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
}
|
||||
if self._is_cloud():
|
||||
update["dirty_paths"] = [path]
|
||||
update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
|
||||
return Command(update=update)
|
||||
|
||||
def sync_write_file(
|
||||
|
|
@ -1036,6 +1123,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
}
|
||||
if self._is_cloud():
|
||||
update["dirty_paths"] = [path]
|
||||
update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
|
||||
if doc_id_to_attach is not None:
|
||||
update["doc_id_by_path"] = {path: doc_id_to_attach}
|
||||
return Command(update=update)
|
||||
|
|
@ -1103,6 +1191,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
return Command(
|
||||
update={
|
||||
"staged_dirs": [validated],
|
||||
"staged_dir_tool_calls": {
|
||||
validated: runtime.tool_call_id,
|
||||
},
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=(
|
||||
|
|
@ -1372,7 +1463,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
files_update: dict[str, Any] = {source: None, dest: source_file_data}
|
||||
update: dict[str, Any] = {
|
||||
"files": files_update,
|
||||
"pending_moves": [{"source": source, "dest": dest, "overwrite": False}],
|
||||
"pending_moves": [
|
||||
{
|
||||
"source": source,
|
||||
"dest": dest,
|
||||
"overwrite": False,
|
||||
"tool_call_id": runtime.tool_call_id,
|
||||
}
|
||||
],
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=(
|
||||
|
|
@ -1396,6 +1494,323 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
update["dirty_paths"] = new_dirty
|
||||
return Command(update=update)
|
||||
|
||||
# ------------------------------------------------------------------ tool: rm
|
||||
|
||||
def _create_rm_tool(self) -> BaseTool:
    """Build the ``rm`` tool (single-file delete) with sync and async entry points.

    In cloud mode the delete is only staged in agent state: a
    ``pending_deletes`` entry plus ``None`` tombstones in ``files`` and
    ``doc_id_by_path``. In desktop mode the backend's ``adelete_file`` is
    awaited right away. Failures are returned as plain ``"Error: ..."``
    strings; successes are returned as a ``Command`` carrying a
    ``ToolMessage``.
    """
    tool_description = (
        self._custom_tool_descriptions.get("rm") or _CLOUD_RM_TOOL_DESCRIPTION
    )

    async def async_rm(
        path: Annotated[
            str,
            "Absolute or relative path to the file to delete.",
        ],
        runtime: ToolRuntime[None, SurfSenseFilesystemState],
    ) -> Command | str:
        if not path or not path.strip():
            return "Error: path is required."

        # Resolve against the session cwd first, then validate the result.
        target = self._resolve_relative(path, runtime)
        try:
            validated = validate_path(target)
        except ValueError as exc:
            return f"Error: {exc}"

        if self._is_cloud():
            # Cloud deletes are confined to strict descendants of DOCUMENTS_ROOT.
            if validated in ("/", DOCUMENTS_ROOT):
                return f"Error: refusing to rm '{validated}'."
            if not validated.startswith(DOCUMENTS_ROOT + "/"):
                return (
                    "Error: cloud rm must target a path under /documents/ "
                    f"(got '{validated}')."
                )

            # The anonymous upload (state key "kb_anon_doc") is read-only.
            anon = runtime.state.get("kb_anon_doc") or {}
            if isinstance(anon, dict) and str(anon.get("path") or "") == validated:
                return "Error: the anonymous uploaded document is read-only."

            # Refuse if the path looks like a directory.
            staged_dirs = list(runtime.state.get("staged_dirs") or [])
            if validated in staged_dirs:
                return (
                    f"Error: '{validated}' is a directory. Use rmdir for "
                    "empty directories."
                )
            pending_dir_deletes = list(
                runtime.state.get("pending_dir_deletes") or []
            )
            if any(
                isinstance(d, dict) and d.get("path") == validated
                for d in pending_dir_deletes
            ):
                return f"Error: '{validated}' is already queued for rmdir."

            backend = self._get_backend(runtime)
            if isinstance(backend, KBPostgresBackend):
                # Detect "is a directory" via `ls`: if the path lists
                # children we know it's a folder. Otherwise we still
                # need to confirm it's a real file before staging.
                children = await backend.als_info(validated)
                if children:
                    return (
                        f"Error: '{validated}' is a directory. Use rmdir for "
                        "empty directories."
                    )

            # Already queued for delete this turn?
            # NOTE: this duplicate case is reported without the "Error:" prefix.
            pending_deletes = list(runtime.state.get("pending_deletes") or [])
            if any(
                isinstance(d, dict) and d.get("path") == validated
                for d in pending_deletes
            ):
                return f"'{validated}' is already queued for deletion."

            # Resolve doc_id (best-effort): file in state or DB.
            files_state = runtime.state.get("files") or {}
            doc_id_by_path = runtime.state.get("doc_id_by_path") or {}
            resolved_doc_id: int | None = doc_id_by_path.get(validated)
            if (
                validated not in files_state
                and resolved_doc_id is None
                and isinstance(backend, KBPostgresBackend)
            ):
                # Nothing known in state: the DB lookup doubles as the
                # existence check for the target file.
                loaded = await backend._load_file_data(validated)
                if loaded is None:
                    return f"Error: file '{validated}' not found."
                # NOTE(review): resolved_doc_id is not read again below —
                # confirm whether it was meant to be attached to the
                # pending_deletes entry.
                _, resolved_doc_id = loaded

            # A None value is the deletion marker for both state maps.
            files_update: dict[str, Any] = {validated: None}
            update: dict[str, Any] = {
                "pending_deletes": [
                    {
                        "path": validated,
                        "tool_call_id": runtime.tool_call_id,
                    }
                ],
                "files": files_update,
                "doc_id_by_path": {validated: None},
                "messages": [
                    ToolMessage(
                        content=(
                            f"Staged delete of '{validated}' (will commit at "
                            "end of turn)."
                        ),
                        tool_call_id=runtime.tool_call_id,
                    )
                ],
            }

            # Drop the path from dirty_paths so a same-turn write+rm
            # doesn't recreate the doc at commit time.
            dirty_paths = list(runtime.state.get("dirty_paths") or [])
            if validated in dirty_paths:
                # The leading _CLEAR sentinel presumably tells the state
                # reducer to replace the list rather than append — verify
                # against the reducer definition.
                new_dirty: list[Any] = [_CLEAR]
                for entry in dirty_paths:
                    if entry != validated:
                        new_dirty.append(entry)
                update["dirty_paths"] = new_dirty
                update["dirty_path_tool_calls"] = {validated: None}

            return Command(update=update)

        # Desktop mode — hit disk immediately.
        backend = self._get_backend(runtime)
        # Looked up dynamically so backends without delete support produce
        # an error string instead of raising AttributeError.
        adelete = getattr(backend, "adelete_file", None)
        if not callable(adelete):
            return "Error: rm is not supported by the active backend."
        res: WriteResult = await adelete(validated)
        if res.error:
            return res.error
        update_desktop: dict[str, Any] = {
            "files": {validated: None},
            "messages": [
                ToolMessage(
                    content=f"Deleted file '{res.path or validated}'",
                    tool_call_id=runtime.tool_call_id,
                )
            ],
        }
        return Command(update=update_desktop)

    def sync_rm(
        path: Annotated[
            str,
            "Absolute or relative path to the file to delete.",
        ],
        runtime: ToolRuntime[None, SurfSenseFilesystemState],
    ) -> Command | str:
        # Synchronous entry point: runs the async implementation to completion.
        return self._run_async_blocking(async_rm(path, runtime))

    return StructuredTool.from_function(
        name="rm",
        description=tool_description,
        func=sync_rm,
        coroutine=async_rm,
    )
|
||||
|
||||
# ------------------------------------------------------------------ tool: rmdir
|
||||
|
||||
def _create_rmdir_tool(self) -> BaseTool:
    """Build the ``rmdir`` tool (empty-directory delete) with sync and async entry points.

    In cloud mode the delete is staged in agent state as a
    ``pending_dir_deletes`` entry. A directory that only exists as a
    same-turn staged ``mkdir`` is simply removed from ``staged_dirs``
    instead. In desktop mode the backend's ``armdir`` is awaited right away.
    Failures are returned as plain ``"Error: ..."`` strings; successes are
    returned as a ``Command`` carrying a ``ToolMessage``.
    """
    tool_description = (
        self._custom_tool_descriptions.get("rmdir") or _CLOUD_RMDIR_TOOL_DESCRIPTION
    )

    async def async_rmdir(
        path: Annotated[
            str,
            "Absolute or relative path of the empty directory to delete.",
        ],
        runtime: ToolRuntime[None, SurfSenseFilesystemState],
    ) -> Command | str:
        if not path or not path.strip():
            return "Error: path is required."

        # Resolve against the session cwd first, then validate the result.
        target = self._resolve_relative(path, runtime)
        try:
            validated = validate_path(target)
        except ValueError as exc:
            return f"Error: {exc}"

        if self._is_cloud():
            # Cloud deletes are confined to strict descendants of DOCUMENTS_ROOT.
            if validated in ("/", DOCUMENTS_ROOT):
                return f"Error: refusing to rmdir '{validated}'."
            if not validated.startswith(DOCUMENTS_ROOT + "/"):
                return (
                    "Error: cloud rmdir must target a path under /documents/ "
                    f"(got '{validated}')."
                )

            # Never remove the directory the session currently sits in,
            # nor any directory above it.
            cwd = self._current_cwd(runtime)
            if validated == cwd or _is_ancestor_of(validated, cwd):
                return (
                    f"Error: cannot rmdir '{validated}' because the current "
                    "cwd is at or under it. cd out first."
                )

            staged_dirs = list(runtime.state.get("staged_dirs") or [])
            pending_dir_deletes = list(
                runtime.state.get("pending_dir_deletes") or []
            )
            # NOTE: this duplicate case is reported without the "Error:" prefix.
            if any(
                isinstance(d, dict) and d.get("path") == validated
                for d in pending_dir_deletes
            ):
                return f"'{validated}' is already queued for deletion."

            backend = self._get_backend(runtime)

            # The path must currently exist either in DB folder paths or
            # in staged_dirs. We rely on KBPostgresBackend.als_info (which
            # already accounts for pending deletes/moves) to evaluate
            # both existence and emptiness against the post-staged view.
            exists_in_staged = validated in staged_dirs
            children: list[Any] = []
            if isinstance(backend, KBPostgresBackend):
                children = list(await backend.als_info(validated))

            # Detect "is a file" — if als_info returns no children but
            # the path is actually a file, we should reject. We use
            # _load_file_data to disambiguate file vs missing folder.
            if (
                isinstance(backend, KBPostgresBackend)
                and not children
                and not exists_in_staged
            ):
                loaded = await backend._load_file_data(validated)
                if loaded is not None:
                    return (
                        f"Error: '{validated}' is a file. Use rm to delete files."
                    )
                # Confirm folder exists in DB by checking the parent listing.
                parent = posixpath.dirname(validated) or "/"
                parent_listing = await backend.als_info(parent)
                parent_has_dir = any(
                    info.get("path") == validated and info.get("is_dir")
                    for info in parent_listing
                )
                if not parent_has_dir:
                    return f"Error: directory '{validated}' not found."

            if children:
                return (
                    f"Error: directory '{validated}' is not empty. "
                    "Remove contents first."
                )

            # Same-turn mkdir un-stage: drop the staged_dirs entry
            # entirely and skip queuing a DB delete (nothing was ever
            # committed).
            if exists_in_staged:
                rest = [d for d in staged_dirs if d != validated]
                return Command(
                    update={
                        # The leading _CLEAR sentinel presumably makes the
                        # reducer replace the list — verify against the
                        # reducer definition.
                        "staged_dirs": [_CLEAR, *rest],
                        "staged_dir_tool_calls": {validated: None},
                        "messages": [
                            ToolMessage(
                                content=(f"Un-staged directory '{validated}'."),
                                tool_call_id=runtime.tool_call_id,
                            )
                        ],
                    }
                )

            return Command(
                update={
                    "pending_dir_deletes": [
                        {
                            "path": validated,
                            "tool_call_id": runtime.tool_call_id,
                        }
                    ],
                    "messages": [
                        ToolMessage(
                            content=(
                                f"Staged rmdir of '{validated}' (will commit "
                                "at end of turn)."
                            ),
                            tool_call_id=runtime.tool_call_id,
                        )
                    ],
                }
            )

        # Desktop mode — hit disk immediately.
        # NOTE(review): the cwd / mount-root guards above apply to cloud
        # mode only; this branch relies on the backend for any such checks.
        backend = self._get_backend(runtime)
        # Looked up dynamically so backends without rmdir support produce
        # an error string instead of raising AttributeError.
        armdir = getattr(backend, "armdir", None)
        if not callable(armdir):
            return "Error: rmdir is not supported by the active backend."
        res: WriteResult = await armdir(validated)
        if res.error:
            return res.error
        return Command(
            update={
                "messages": [
                    ToolMessage(
                        content=f"Deleted directory '{res.path or validated}'",
                        tool_call_id=runtime.tool_call_id,
                    )
                ],
            }
        )

    def sync_rmdir(
        path: Annotated[
            str,
            "Absolute or relative path of the empty directory to delete.",
        ],
        runtime: ToolRuntime[None, SurfSenseFilesystemState],
    ) -> Command | str:
        # Synchronous entry point: runs the async implementation to completion.
        return self._run_async_blocking(async_rmdir(path, runtime))

    return StructuredTool.from_function(
        name="rmdir",
        description=tool_description,
        func=sync_rmdir,
        coroutine=async_rmdir,
    )
|
||||
|
||||
# ------------------------------------------------------------------ tool: list_tree
|
||||
|
||||
def _create_list_tree_tool(self) -> BaseTool:
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -115,6 +115,12 @@ class KBPostgresBackend(BackendProtocol):
|
|||
def _pending_moves(self) -> list[dict[str, Any]]:
|
||||
return list(self.state.get("pending_moves") or [])
|
||||
|
||||
def _pending_deletes(self) -> list[dict[str, Any]]:
|
||||
return list(self.state.get("pending_deletes") or [])
|
||||
|
||||
def _pending_dir_deletes(self) -> list[dict[str, Any]]:
|
||||
return list(self.state.get("pending_dir_deletes") or [])
|
||||
|
||||
def _kb_anon_doc(self) -> dict[str, Any] | None:
|
||||
anon = self.state.get("kb_anon_doc")
|
||||
return anon if isinstance(anon, dict) else None
|
||||
|
|
@ -140,18 +146,28 @@ class KBPostgresBackend(BackendProtocol):
|
|||
return path
|
||||
return path.rstrip("/") if path != "/" else path
|
||||
|
||||
def _moved_view_paths(
|
||||
def _pending_filesystem_view(
|
||||
self,
|
||||
existing: dict[str, dict[str, Any]],
|
||||
) -> tuple[set[str], dict[str, str]]:
|
||||
"""Apply ``pending_moves`` to a path set and return ``(removed, alias)``.
|
||||
) -> tuple[set[str], dict[str, str], set[str]]:
|
||||
"""Compute removed/aliased/dir-suppressed paths from staged ops.
|
||||
|
||||
Removed paths should disappear from listings; ``alias[source] = dest``
|
||||
means a virtual entry should appear at ``dest`` even if no DB row is
|
||||
yet there.
|
||||
Returns ``(removed, alias, deleted_dirs)`` where:
|
||||
|
||||
* ``removed`` — paths to drop from listings (sources of pending moves
|
||||
AND paths queued for ``rm``).
|
||||
* ``alias`` — ``{source: dest}`` for pending moves; the dest should
|
||||
appear as a virtual entry even when no DB row is at that path yet.
|
||||
* ``deleted_dirs`` — folder paths queued for ``rmdir``; their entire
|
||||
subtree (descendants) is suppressed from listings/glob/grep.
|
||||
|
||||
Entries in ``existing`` (the ``files`` state cache) keyed by a
|
||||
removed path are popped so a same-turn delete-after-write doesn't
|
||||
leave a stale virtual file in listings.
|
||||
"""
|
||||
removed: set[str] = set()
|
||||
alias: dict[str, str] = {}
|
||||
deleted_dirs: set[str] = set()
|
||||
for move in self._pending_moves():
|
||||
src = move.get("source")
|
||||
dst = move.get("dest")
|
||||
|
|
@ -160,7 +176,23 @@ class KBPostgresBackend(BackendProtocol):
|
|||
removed.add(src)
|
||||
alias[src] = dst
|
||||
existing.pop(src, None)
|
||||
return removed, alias
|
||||
for entry in self._pending_deletes():
|
||||
path = entry.get("path") if isinstance(entry, dict) else None
|
||||
if not path:
|
||||
continue
|
||||
removed.add(path)
|
||||
existing.pop(path, None)
|
||||
for entry in self._pending_dir_deletes():
|
||||
path = entry.get("path") if isinstance(entry, dict) else None
|
||||
if not path:
|
||||
continue
|
||||
deleted_dirs.add(path)
|
||||
return removed, alias, deleted_dirs
|
||||
|
||||
@staticmethod
|
||||
def _is_dir_suppressed(path: str, deleted_dirs: set[str]) -> bool:
|
||||
"""Return True iff ``path`` is at-or-under any directory in ``deleted_dirs``."""
|
||||
return any(path == d or _is_under(path, d) for d in deleted_dirs)
|
||||
|
||||
# ------------------------------------------------------------------ ls/read
|
||||
|
||||
|
|
@ -189,7 +221,7 @@ class KBPostgresBackend(BackendProtocol):
|
|||
seen.add(anon_path)
|
||||
|
||||
files = self._state_files()
|
||||
moved_removed, moved_alias = self._moved_view_paths(files)
|
||||
moved_removed, moved_alias, deleted_dirs = self._pending_filesystem_view(files)
|
||||
|
||||
if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
|
||||
try:
|
||||
|
|
@ -203,7 +235,12 @@ class KBPostgresBackend(BackendProtocol):
|
|||
|
||||
for info in db_infos:
|
||||
p = info.get("path", "")
|
||||
if not p or p in seen or p in moved_removed:
|
||||
if (
|
||||
not p
|
||||
or p in seen
|
||||
or p in moved_removed
|
||||
or self._is_dir_suppressed(p, deleted_dirs)
|
||||
):
|
||||
continue
|
||||
infos.append(info)
|
||||
seen.add(p)
|
||||
|
|
@ -212,6 +249,8 @@ class KBPostgresBackend(BackendProtocol):
|
|||
if src not in seen:
|
||||
if not _is_under(dst, normalized):
|
||||
continue
|
||||
if self._is_dir_suppressed(dst, deleted_dirs):
|
||||
continue
|
||||
rel = (
|
||||
dst[len(normalized) :].lstrip("/")
|
||||
if normalized != "/"
|
||||
|
|
@ -247,6 +286,8 @@ class KBPostgresBackend(BackendProtocol):
|
|||
continue
|
||||
if not _is_under(staged, normalized):
|
||||
continue
|
||||
if self._is_dir_suppressed(staged, deleted_dirs):
|
||||
continue
|
||||
rel = (
|
||||
staged[len(normalized) :].lstrip("/")
|
||||
if normalized != "/"
|
||||
|
|
@ -265,14 +306,26 @@ class KBPostgresBackend(BackendProtocol):
|
|||
for sub in sorted(subdir_paths):
|
||||
if sub in seen:
|
||||
continue
|
||||
if self._is_dir_suppressed(sub, deleted_dirs):
|
||||
continue
|
||||
infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at=""))
|
||||
seen.add(sub)
|
||||
|
||||
for path_key, fd in files.items():
|
||||
if not isinstance(path_key, str) or path_key in seen:
|
||||
continue
|
||||
# Tombstones (None values) are deletion markers from `rm`. The
|
||||
# deepagents reducer normally pops them, but a stale tombstone
|
||||
# surviving a checkpoint must NOT be reported as a child here —
|
||||
# otherwise rmdir mistakenly sees the deleted file as content.
|
||||
if fd is None:
|
||||
continue
|
||||
if not _is_under(path_key, normalized) or path_key == normalized:
|
||||
continue
|
||||
if path_key in moved_removed or self._is_dir_suppressed(
|
||||
path_key, deleted_dirs
|
||||
):
|
||||
continue
|
||||
if normalized == "/":
|
||||
rel = path_key.lstrip("/")
|
||||
else:
|
||||
|
|
@ -550,10 +603,12 @@ class KBPostgresBackend(BackendProtocol):
|
|||
seen: set[str] = set()
|
||||
|
||||
files = self._state_files()
|
||||
moved_removed, _ = self._moved_view_paths(files)
|
||||
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
|
||||
regex = re.compile(fnmatch.translate(pattern))
|
||||
for path_key, fd in files.items():
|
||||
if path_key in moved_removed:
|
||||
if path_key in moved_removed or self._is_dir_suppressed(
|
||||
path_key, deleted_dirs
|
||||
):
|
||||
continue
|
||||
if not _is_under(path_key, normalized):
|
||||
continue
|
||||
|
|
@ -595,7 +650,11 @@ class KBPostgresBackend(BackendProtocol):
|
|||
folder_id=row.folder_id,
|
||||
index=index,
|
||||
)
|
||||
if candidate in seen or candidate in moved_removed:
|
||||
if (
|
||||
candidate in seen
|
||||
or candidate in moved_removed
|
||||
or self._is_dir_suppressed(candidate, deleted_dirs)
|
||||
):
|
||||
continue
|
||||
if not _is_under(candidate, normalized):
|
||||
continue
|
||||
|
|
@ -634,10 +693,12 @@ class KBPostgresBackend(BackendProtocol):
|
|||
matches: list[GrepMatch] = []
|
||||
|
||||
files = self._state_files()
|
||||
moved_removed, _ = self._moved_view_paths(files)
|
||||
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
|
||||
glob_re = re.compile(fnmatch.translate(glob)) if glob else None
|
||||
for path_key, fd in files.items():
|
||||
if path_key in moved_removed:
|
||||
if path_key in moved_removed or self._is_dir_suppressed(
|
||||
path_key, deleted_dirs
|
||||
):
|
||||
continue
|
||||
if not _is_under(path_key, normalized):
|
||||
continue
|
||||
|
|
@ -695,7 +756,11 @@ class KBPostgresBackend(BackendProtocol):
|
|||
)
|
||||
for doc_id, chunk_id, content in chunk_buffer:
|
||||
candidate = doc_id_to_path.get(doc_id)
|
||||
if not candidate or candidate in moved_removed:
|
||||
if (
|
||||
not candidate
|
||||
or candidate in moved_removed
|
||||
or self._is_dir_suppressed(candidate, deleted_dirs)
|
||||
):
|
||||
continue
|
||||
if not _is_under(candidate, normalized):
|
||||
continue
|
||||
|
|
@ -769,7 +834,7 @@ class KBPostgresBackend(BackendProtocol):
|
|||
return {"entries": [], "truncated": False}
|
||||
|
||||
files = self._state_files()
|
||||
moved_removed, _ = self._moved_view_paths(files)
|
||||
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
|
||||
anon = self._kb_anon_doc()
|
||||
anon_path = str(anon.get("path") or "") if anon else ""
|
||||
|
||||
|
|
@ -795,6 +860,8 @@ class KBPostgresBackend(BackendProtocol):
|
|||
for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]):
|
||||
if not _is_under(fpath, normalized):
|
||||
continue
|
||||
if self._is_dir_suppressed(fpath, deleted_dirs):
|
||||
continue
|
||||
depth = _depth_of(fpath)
|
||||
if max_depth is not None and depth > max_depth:
|
||||
continue
|
||||
|
|
@ -811,6 +878,8 @@ class KBPostgresBackend(BackendProtocol):
|
|||
for staged in self._staged_dirs():
|
||||
if not _is_under(staged, normalized):
|
||||
continue
|
||||
if self._is_dir_suppressed(staged, deleted_dirs):
|
||||
continue
|
||||
depth = _depth_of(staged)
|
||||
if max_depth is not None and depth > max_depth:
|
||||
continue
|
||||
|
|
@ -835,7 +904,9 @@ class KBPostgresBackend(BackendProtocol):
|
|||
folder_id=row.folder_id,
|
||||
index=index,
|
||||
)
|
||||
if candidate in moved_removed:
|
||||
if candidate in moved_removed or self._is_dir_suppressed(
|
||||
candidate, deleted_dirs
|
||||
):
|
||||
continue
|
||||
if not _is_under(candidate, normalized):
|
||||
continue
|
||||
|
|
@ -875,6 +946,10 @@ class KBPostgresBackend(BackendProtocol):
|
|||
continue
|
||||
if not _is_under(path_key, normalized):
|
||||
continue
|
||||
if path_key in moved_removed or self._is_dir_suppressed(
|
||||
path_key, deleted_dirs
|
||||
):
|
||||
continue
|
||||
if any(e["path"] == path_key for e in entries):
|
||||
continue
|
||||
if not (
|
||||
|
|
|
|||
|
|
@ -201,6 +201,12 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
)
|
||||
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
|
||||
|
||||
# Pre-compute which folders have at least one descendant (folder or doc).
|
||||
# A folder is "empty" iff no path in `all_paths` is strictly under it.
|
||||
# Used to emit an explicit "(empty)" marker so the LLM doesn't have to
|
||||
# infer emptiness from indentation alone.
|
||||
non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)
|
||||
|
||||
lines: list[str] = []
|
||||
for path in all_paths:
|
||||
depth = (
|
||||
|
|
@ -214,7 +220,10 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
|
||||
)
|
||||
if is_dir:
|
||||
lines.append(f"{indent}{display}/")
|
||||
if path != DOCUMENTS_ROOT and path not in non_empty_folders:
|
||||
lines.append(f"{indent}{display}/ (empty)")
|
||||
else:
|
||||
lines.append(f"{indent}{display}/")
|
||||
else:
|
||||
lines.append(f"{indent}{display}")
|
||||
if len(lines) >= self.max_entries:
|
||||
|
|
@ -235,6 +244,35 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
return self._format_root_summary(folder_paths, doc_paths)
|
||||
|
||||
@staticmethod
|
||||
def _compute_non_empty_folders(
|
||||
folder_paths: list[str], doc_paths: list[str]
|
||||
) -> set[str]:
|
||||
"""Return the set of folder paths that contain at least one descendant.
|
||||
|
||||
A folder is "non-empty" if any document path or any other folder path
|
||||
is strictly under it. Documents propagate emptiness up to every
|
||||
ancestor folder, while a sub-folder only marks its direct ancestors
|
||||
non-empty (so a chain of empty folders all read ``(empty)``).
|
||||
"""
|
||||
non_empty: set[str] = set()
|
||||
folder_set = set(folder_paths)
|
||||
|
||||
for doc_path in doc_paths:
|
||||
parent = doc_path.rsplit("/", 1)[0]
|
||||
while parent and parent != DOCUMENTS_ROOT:
|
||||
if parent in folder_set:
|
||||
non_empty.add(parent)
|
||||
parent = parent.rsplit("/", 1)[0]
|
||||
|
||||
for child in folder_paths:
|
||||
parent = child.rsplit("/", 1)[0]
|
||||
while parent and parent != DOCUMENTS_ROOT and parent in folder_set:
|
||||
non_empty.add(parent)
|
||||
parent = parent.rsplit("/", 1)[0]
|
||||
|
||||
return non_empty
|
||||
|
||||
def _format_root_summary(
|
||||
self, folder_paths: list[str], doc_paths: list[str]
|
||||
) -> str:
|
||||
|
|
|
|||
|
|
@ -360,6 +360,74 @@ class LocalFolderBackend:
|
|||
self.move, source_path, destination_path, overwrite
|
||||
)
|
||||
|
||||
def delete_file(self, file_path: str) -> WriteResult:
    """Permanently remove one file addressed by a virtual path.

    Behaves roughly like POSIX ``rm path``: no ``-r`` recursion and no glob
    expansion. Every failure is reported through ``WriteResult.error``
    rather than raised: an unresolvable path, a missing path, a path that
    is a directory, or an ``OSError`` from the unlink itself.
    """
    try:
        resolved = self._resolve_virtual(file_path)
    except ValueError:
        return WriteResult(error=f"Error: Invalid path '{file_path}'")

    with self._lock_for(file_path):
        failure: str | None = None
        if not resolved.exists():
            failure = f"Error: File '{file_path}' not found"
        elif resolved.is_dir():
            failure = (
                f"Error: '{file_path}' is a directory. "
                "Use rmdir for empty directories."
            )
        else:
            try:
                os.unlink(resolved)
            except OSError as exc:
                failure = f"Error: failed to delete '{file_path}': {exc}"

        if failure is not None:
            return WriteResult(error=failure)
        return WriteResult(path=file_path, files_update=None)
|
||||
|
||||
async def adelete_file(self, file_path: str) -> WriteResult:
    """Async wrapper: run ``delete_file`` in a worker thread and return its result."""
    outcome = await asyncio.to_thread(self.delete_file, file_path)
    return outcome
|
||||
|
||||
def rmdir(self, dir_path: str) -> WriteResult:
    """Permanently remove one empty directory addressed by a virtual path.

    ``os.rmdir`` already refuses non-empty directories; the emptiness probe
    below exists only so the caller gets a clearer error message.

    Args:
        dir_path: Virtual path of the directory to remove.

    Returns:
        ``WriteResult(path=dir_path, files_update=None)`` on success, or a
        ``WriteResult`` whose ``error`` explains the refusal: unresolvable
        path, missing path, path is not a directory, directory still has
        entries, or the OS rejected the removal.

    NOTE(review): no explicit "is this the root?" check is visible in this
    method; presumably ``_resolve_virtual`` or the caller guards the root
    directory -- confirm.
    """
    try:
        target = self._resolve_virtual(dir_path)
    except ValueError:
        return WriteResult(error=f"Error: Invalid path '{dir_path}'")

    with self._lock_for(dir_path):
        if not target.exists():
            return WriteResult(error=f"Error: Directory '{dir_path}' not found")
        if not target.is_dir():
            return WriteResult(error=f"Error: '{dir_path}' is not a directory")

        # Only the first entry is pulled: one is enough to know the
        # directory is occupied.
        first_entry = next(target.iterdir(), None)
        if first_entry is not None:
            occupied_message = (
                f"Error: directory '{dir_path}' is not empty. "
                "Remove its contents first."
            )
            return WriteResult(error=occupied_message)

        try:
            os.rmdir(target)
        except OSError as exc:
            return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}")
        return WriteResult(path=dir_path, files_update=None)
|
||||
|
||||
async def armdir(self, dir_path: str) -> WriteResult:
    """Async counterpart of ``rmdir``; runs it via ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.rmdir, dir_path)
    return outcome
|
||||
|
||||
def edit(
|
||||
self,
|
||||
file_path: str,
|
||||
|
|
|
|||
|
|
@ -285,6 +285,34 @@ class MultiRootLocalFolderBackend:
|
|||
overwrite,
|
||||
)
|
||||
|
||||
def delete_file(self, file_path: str) -> WriteResult:
    """Forward a file delete to the backend registered for the path's mount.

    ``file_path`` is split into ``(mount, local_path)``; the per-mount
    backend deletes ``local_path`` and, when the result carries a path, that
    path is re-prefixed with the mount before being returned.

    Returns:
        The backend's ``WriteResult`` (mutated in place), or an error
        ``WriteResult`` when the path cannot be split into mount + local
        path.
    """
    try:
        mount_name, inner_path = self._split_mount_path(file_path)
    except ValueError as exc:
        return WriteResult(error=f"Error: {exc}")

    outcome = self._mount_to_backend[mount_name].delete_file(inner_path)
    if not outcome.path:
        # Failure results have no path to rewrite; pass them through.
        return outcome
    outcome.path = self._prefix_mount_path(mount_name, outcome.path)
    return outcome
|
||||
|
||||
async def adelete_file(self, file_path: str) -> WriteResult:
    """Async counterpart of ``delete_file``; runs it via ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.delete_file, file_path)
    return outcome
|
||||
|
||||
def rmdir(self, dir_path: str) -> WriteResult:
    """Forward a directory removal to the backend registered for the mount.

    A local path of ``"/"`` (the mount root itself) is rejected here and
    never reaches the backend. Otherwise the per-mount backend handles the
    removal and, when the result carries a path, that path is re-prefixed
    with the mount before being returned.

    Returns:
        The backend's ``WriteResult`` (mutated in place), or an error
        ``WriteResult`` when the path cannot be split or targets a mount
        root.
    """
    try:
        mount_name, inner_path = self._split_mount_path(dir_path)
    except ValueError as exc:
        return WriteResult(error=f"Error: {exc}")

    if inner_path == "/":
        return WriteResult(error=f"Error: cannot rmdir mount root '{dir_path}'")

    outcome = self._mount_to_backend[mount_name].rmdir(inner_path)
    if not outcome.path:
        # Failure results have no path to rewrite; pass them through.
        return outcome
    outcome.path = self._prefix_mount_path(mount_name, outcome.path)
    return outcome
|
||||
|
||||
async def armdir(self, dir_path: str) -> WriteResult:
    """Async counterpart of ``rmdir``; runs it via ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.rmdir, dir_path)
    return outcome
|
||||
|
||||
def edit(
|
||||
self,
|
||||
file_path: str,
|
||||
|
|
|
|||
|
|
@ -181,9 +181,13 @@ def _initial_filesystem_state() -> dict[str, Any]:
|
|||
return {
|
||||
"cwd": "/documents",
|
||||
"staged_dirs": [],
|
||||
"staged_dir_tool_calls": {},
|
||||
"pending_moves": [],
|
||||
"pending_deletes": [],
|
||||
"pending_dir_deletes": [],
|
||||
"doc_id_by_path": {},
|
||||
"dirty_paths": [],
|
||||
"dirty_path_tool_calls": {},
|
||||
"kb_priority": [],
|
||||
"kb_matched_chunk_ids": {},
|
||||
"kb_anon_doc": None,
|
||||
|
|
|
|||
|
|
@ -90,6 +90,8 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = (
|
|||
"write_file",
|
||||
"move_file",
|
||||
"mkdir",
|
||||
"rm",
|
||||
"rmdir",
|
||||
"update_memory",
|
||||
"update_memory_team",
|
||||
"update_memory_private",
|
||||
|
|
|
|||
|
|
@ -30,6 +30,35 @@ from langgraph.types import interrupt
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool names that skip the human-in-the-loop prompt by default.
#
# Rationale: these share the safety profile of ``write_file`` against the
# SurfSense KB -- each call creates ONE artifact in the user's own
# workspace with no external visibility (drafts aren't sent; new files
# aren't shared unless the user shares them later). Auto-approving them
# lets the agent compose drafts and seed scratch files without a popup on
# every call.
#
# Call sites are unchanged: members still go through ``request_approval``,
# which returns immediately with ``decision_type="auto_approved"`` and the
# original params untouched. Logging, metadata fetching and account
# fallbacks at the call site therefore keep working; the only behavior
# change is that no interrupt fires.
#
# To re-enable prompting, the future per-search-space rules table
# (``agent_permission_rules``) takes precedence -- see the ``# (future)``
# layer-3 comment in :mod:`app.agents.new_chat.chat_deepagent`.
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
    (
        "create_gmail_draft",
        "update_gmail_draft",
        "create_notion_page",
        "create_confluence_page",
        "create_google_drive_file",
        "create_dropbox_file",
        "create_onedrive_file",
    )
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class HITLResult:
|
||||
"""Outcome of a human-in-the-loop approval request."""
|
||||
|
|
@ -119,6 +148,19 @@ def request_approval(
|
|||
logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
|
||||
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
|
||||
|
||||
if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
|
||||
# Default policy: low-stakes creation tools (drafts + new-file
|
||||
# creates) skip HITL because they're as recoverable as a local
|
||||
# ``write_file`` against the SurfSense KB. The user can still
|
||||
# delete the artifact in <30s if it's wrong.
|
||||
logger.info(
|
||||
"Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
|
||||
tool_name,
|
||||
)
|
||||
return HITLResult(
|
||||
rejected=False, decision_type="auto_approved", params=dict(params)
|
||||
)
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": action_type,
|
||||
|
|
|
|||
|
|
@ -689,6 +689,12 @@ class NewChatMessage(BaseModel, TimestampMixin):
|
|||
index=True,
|
||||
)
|
||||
|
||||
# Per-turn correlation id sourced from ``configurable.turn_id`` at
|
||||
# streaming time (``f"{chat_id}:{ms}"``). Nullable because legacy rows
|
||||
# predate the column. Used by C1's edit-from-arbitrary-position to map
|
||||
# a message back to the LangGraph checkpoint that produced its turn.
|
||||
turn_id = Column(String(64), nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
thread = relationship("NewChatThread", back_populates="messages")
|
||||
author = relationship("User")
|
||||
|
|
@ -2292,7 +2298,13 @@ class AgentActionLog(BaseModel):
|
|||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
# ``turn_id`` historically held the LangChain ``tool_call.id``. It has
|
||||
# been renamed to ``tool_call_id`` (with a parallel column kept for one
|
||||
# release for back-compat). The real chat-turn id lives in
|
||||
# ``chat_turn_id`` and is sourced from ``configurable.turn_id``.
|
||||
turn_id = Column(String(64), nullable=True, index=True)
|
||||
tool_call_id = Column(String(64), nullable=True, index=True)
|
||||
chat_turn_id = Column(String(64), nullable=True, index=True)
|
||||
message_id = Column(String(128), nullable=True, index=True)
|
||||
tool_name = Column(String(255), nullable=False, index=True)
|
||||
args = Column(JSONB, nullable=True)
|
||||
|
|
@ -2318,6 +2330,16 @@ class AgentActionLog(BaseModel):
|
|||
|
||||
__table_args__ = (
|
||||
Index("ix_agent_action_log_thread_created", "thread_id", "created_at"),
|
||||
# Partial unique index enforces "at most one revert per
|
||||
# original action". Created in migration 137 with
|
||||
# ``WHERE reverse_of IS NOT NULL`` so non-revert rows
|
||||
# (the vast majority) are unaffected and NULLs don't collide.
|
||||
Index(
|
||||
"ux_agent_action_log_reverse_of",
|
||||
"reverse_of",
|
||||
unique=True,
|
||||
postgresql_where=text("reverse_of IS NOT NULL"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -2332,10 +2354,13 @@ class DocumentRevision(BaseModel):
|
|||
|
||||
__tablename__ = "document_revisions"
|
||||
|
||||
# ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
|
||||
# hard-delete it describes — without that, ``rm`` would wipe the row
|
||||
# we'd need to undo it. See migration ``134_relax_revision_fks``.
|
||||
document_id = Column(
|
||||
Integer,
|
||||
ForeignKey("documents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
ForeignKey("documents.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
search_space_id = Column(
|
||||
|
|
@ -2370,10 +2395,13 @@ class FolderRevision(BaseModel):
|
|||
|
||||
__tablename__ = "folder_revisions"
|
||||
|
||||
# ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
|
||||
# hard-delete it describes — without that, ``rmdir`` would wipe the
|
||||
# row we'd need to undo it. See migration ``134_relax_revision_fks``.
|
||||
folder_id = Column(
|
||||
Integer,
|
||||
ForeignKey("folders.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
ForeignKey("folders.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
search_space_id = Column(
|
||||
|
|
|
|||
|
|
@ -65,6 +65,13 @@ class AgentActionRead(BaseModel):
|
|||
reverse_of: int | None
|
||||
reverted_by_action_id: int | None
|
||||
is_revert_action: bool
|
||||
# Correlation ids added in migration 135. ``tool_call_id`` is the
|
||||
# LangChain tool-call id (joinable to ``data-action-log`` SSE events
|
||||
# via ``langchainToolCallId``). ``chat_turn_id`` is the per-turn id
|
||||
# from ``configurable.turn_id`` (used by the
|
||||
# ``revert-turn/{chat_turn_id}`` endpoint).
|
||||
tool_call_id: str | None = None
|
||||
chat_turn_id: str | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
|
|
@ -172,6 +179,8 @@ async def list_thread_actions(
|
|||
reverse_of=row.reverse_of,
|
||||
reverted_by_action_id=revert_map.get(row.id),
|
||||
is_revert_action=row.reverse_of is not None,
|
||||
tool_call_id=row.tool_call_id,
|
||||
chat_turn_id=row.chat_turn_id,
|
||||
created_at=row.created_at,
|
||||
)
|
||||
for row in rows
|
||||
|
|
|
|||
|
|
@ -11,14 +11,25 @@ flag flips. Once enabled, the route runs:
|
|||
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
|
||||
5. Idempotent on retries: if the same action is reverted twice the second
|
||||
call returns 409 ``"already reverted"``.
|
||||
|
||||
This module also hosts the per-turn batch endpoint
|
||||
``POST /api/threads/{thread_id}/revert-turn/{chat_turn_id}``. It
|
||||
walks every reversible action emitted during a chat turn in reverse
|
||||
``created_at`` order and reverts each independently. Partial success is the
|
||||
common case — the response always contains a per-action result list and a
|
||||
``status`` of ``"ok"`` or ``"partial"``; we never collapse the batch into a
|
||||
whole-batch 4xx.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.feature_flags import get_flags
|
||||
|
|
@ -97,6 +108,16 @@ async def revert_agent_action(
|
|||
action=action,
|
||||
requester_user_id=str(user.id) if user is not None else None,
|
||||
)
|
||||
except IntegrityError:
|
||||
# Partial unique index ``ux_agent_action_log_reverse_of`` caught
|
||||
# a concurrent revert. Translate to the existing 409 "already
|
||||
# reverted" contract so racing clients see consistent
|
||||
# behaviour with the pre-flight TOCTOU check above.
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="This action has already been reverted.",
|
||||
) from None
|
||||
except Exception as err:
|
||||
logger.exception("Revert dispatch raised for action_id=%s", action_id)
|
||||
await session.rollback()
|
||||
|
|
@ -105,7 +126,16 @@ async def revert_agent_action(
|
|||
) from err
|
||||
|
||||
if outcome.status == "ok":
|
||||
await session.commit()
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
# Race lost on commit (constraint enforced at flush in some
|
||||
# configs but at commit in others — defensive).
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="This action has already been reverted.",
|
||||
) from None
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": outcome.message,
|
||||
|
|
@ -122,3 +152,357 @@ async def revert_agent_action(
|
|||
raise HTTPException(status_code=501, detail=outcome.message)
|
||||
# not_reversible
|
||||
raise HTTPException(status_code=409, detail=outcome.message)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-turn revert batch endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Closed set of per-row outcomes reported by the revert-turn batch. Each
# value has a matching counter field on ``RevertTurnResponse``.
PerActionStatus = Literal[
    "reverted",
    "already_reverted",
    "not_reversible",
    "permission_denied",
    "failed",
    "skipped",
]
|
||||
|
||||
|
||||
class RevertTurnActionResult(BaseModel):
    """Per-action outcome inside a ``revert-turn`` batch response."""

    # Id and tool name of the original ``AgentActionLog`` row this entry
    # reports on.
    action_id: int
    tool_name: str
    # One of the ``PerActionStatus`` literals.
    status: PerActionStatus
    # Human-readable explanation; populated from the revert outcome or a
    # fixed skip / permission message, left ``None`` otherwise.
    message: str | None = None
    # Id of the revert row: the newly created one for ``reverted``, or the
    # pre-existing one for ``already_reverted``.
    new_action_id: int | None = None
    # Stringified exception; only set for ``failed`` entries.
    error: str | None = None
|
||||
|
||||
|
||||
class RevertTurnResponse(BaseModel):
    """Top-level response for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.

    ``status`` is ``"ok"`` only when every reversible row succeeded. Any
    ``failed`` / ``not_reversible`` / ``permission_denied`` entry downgrades
    it to ``"partial"``. Empty turns (no rows) return ``"ok"`` with an empty
    ``results`` list — callers should treat that as a no-op.

    Counter invariant:
    ``total == reverted + already_reverted + not_reversible
    + permission_denied + failed + skipped``

    Frontend toasts and the ``RevertTurnButton`` summary rely on this
    invariant to display "X of Y reverted, Z could not be undone" without
    silently dropping ``permission_denied`` or ``skipped`` rows.
    """

    # Batch-level verdict; see the class docstring for the downgrade rule.
    status: Literal["ok", "partial"]
    # Echo of the turn id taken from the request path.
    chat_turn_id: str
    # Number of action-log rows examined for the turn.
    total: int
    # Per-status counters; together they sum to ``total``.
    reverted: int
    already_reverted: int
    not_reversible: int
    permission_denied: int = 0
    failed: int = 0
    skipped: int = 0
    # One entry per examined row, in processing order.
    results: list[RevertTurnActionResult]
|
||||
|
||||
|
||||
def _classify_outcome(outcome: RevertOutcome) -> PerActionStatus:
|
||||
if outcome.status == "ok":
|
||||
return "reverted"
|
||||
if outcome.status == "permission_denied":
|
||||
return "permission_denied"
|
||||
# ``not_found`` / ``tool_unavailable`` / ``reverse_not_implemented`` /
|
||||
# ``not_reversible`` are all surfaced to the caller as "not_reversible"
|
||||
# — they share the same UX (this row cannot be undone) and only the
|
||||
# ``message`` differs.
|
||||
return "not_reversible"
|
||||
|
||||
|
||||
async def _was_already_reverted(session: AsyncSession, *, action_id: int) -> int | None:
    """Look up the revert row that targets ``action_id``, if one exists.

    Single-id variant, kept for the post-``IntegrityError`` path where the
    caller already knows it lost a race for one specific action.

    Returns:
        The id of the first ``AgentActionLog`` row whose ``reverse_of``
        equals ``action_id``, or ``None`` when there is none.
    """
    revert_lookup = select(AgentActionLog.id).where(
        AgentActionLog.reverse_of == action_id
    )
    return (await session.execute(revert_lookup)).scalars().first()
|
||||
|
||||
|
||||
async def _was_already_reverted_batch(
|
||||
session: AsyncSession, *, action_ids: list[int]
|
||||
) -> dict[int, int]:
|
||||
"""Batch idempotency probe for the revert-turn loop.
|
||||
|
||||
Replaces N individual ``SELECT id WHERE reverse_of = :id`` queries
|
||||
(one per row in the turn) with a single ``SELECT id, reverse_of
|
||||
WHERE reverse_of IN (:ids)``. The route still iterates rows in
|
||||
reverse-chronological order, but the membership check is O(1) per
|
||||
iteration after this query. For a turn with 30 actions that's 30
|
||||
fewer round-trips through asyncpg + a smaller transaction footprint.
|
||||
|
||||
Returns a ``{original_action_id -> revert_action_id}`` map. Missing
|
||||
keys mean "not yet reverted" — callers should treat them as
|
||||
eligible for revert.
|
||||
"""
|
||||
if not action_ids:
|
||||
return {}
|
||||
stmt = select(AgentActionLog.id, AgentActionLog.reverse_of).where(
|
||||
AgentActionLog.reverse_of.in_(action_ids)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return {
|
||||
original_id: revert_id
|
||||
for revert_id, original_id in result.all()
|
||||
if original_id is not None
|
||||
}
|
||||
|
||||
|
||||
@router.post(
    "/threads/{thread_id}/revert-turn/{chat_turn_id}",
    response_model=RevertTurnResponse,
)
async def revert_agent_turn(
    thread_id: int,
    chat_turn_id: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
) -> RevertTurnResponse:
    """Undo, newest first, every action logged for ``chat_turn_id``.

    ``AgentActionLog`` rows of the turn are processed in descending
    ``created_at`` order (ties broken by descending ``id``) so dependent
    actions (e.g. ``mkdir`` followed by a ``write_file`` inside the new
    folder) unwind in the right sequence. Each revert runs inside its own
    SAVEPOINT, so one failing row never discards the work done for others.

    Partial success is expected and is returned as a normal response;
    callers must inspect ``results[*].status`` for rows needing attention.

    Raises:
        HTTPException: 503 when the revert route is disabled by feature
            flags, 404 when the thread does not exist, 500 when the final
            commit fails.
    """
    flags = get_flags()
    route_enabled = not flags.disable_new_agent_stack and flags.enable_revert_route
    if not route_enabled:
        raise HTTPException(
            status_code=503,
            detail=(
                "Revert is not available on this deployment yet. The route "
                "ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to "
                "enable it."
            ),
        )

    if await load_thread(session, thread_id=thread_id) is None:
        raise HTTPException(status_code=404, detail="Thread not found.")

    # Newest mutation first; ``id`` is the deterministic tiebreaker for
    # rows written in the same millisecond.
    turn_query = (
        select(AgentActionLog)
        .where(
            AgentActionLog.thread_id == thread_id,
            AgentActionLog.chat_turn_id == chat_turn_id,
        )
        .order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc())
    )
    turn_actions = (await session.execute(turn_query)).scalars().all()

    requester_user_id = None if user is None else str(user.id)

    # The tally must stay exhaustive so ``total == sum(tally.values())``
    # holds; the frontend's "X of Y reverted" math depends on it.
    tally: dict[str, int] = {
        "reverted": 0,
        "already_reverted": 0,
        "not_reversible": 0,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }
    per_action: list[RevertTurnActionResult] = []

    def _note(entry: AgentActionLog, status: PerActionStatus, **details) -> None:
        # Bump the counter and append the result entry together so the
        # two can never drift apart.
        tally[status] += 1
        per_action.append(
            RevertTurnActionResult(
                action_id=entry.id,
                tool_name=entry.tool_name,
                status=status,
                **details,
            )
        )

    # One batched idempotency probe for the whole turn. Rows that are
    # themselves reverts are excluded here and skipped in the loop.
    original_ids = [entry.id for entry in turn_actions if entry.reverse_of is None]
    prior_reverts = await _was_already_reverted_batch(session, action_ids=original_ids)

    for entry in turn_actions:
        # Reverting a revert is meaningless inside a batch: the user wants
        # to wipe the original effects, not chase tail.
        if entry.reverse_of is not None:
            _note(entry, "skipped", message="Row is itself a revert action; skipped.")
            continue

        # Idempotency: report "already_reverted" instead of failing.
        prior_revert_id = prior_reverts.get(entry.id)
        if prior_revert_id is not None:
            _note(entry, "already_reverted", new_action_id=prior_revert_id)
            continue

        allowed = can_revert(
            requester_user_id=requester_user_id,
            action=entry,
            is_admin=False,
        )
        if not allowed:
            _note(
                entry,
                "permission_denied",
                message="You are not allowed to revert this action.",
            )
            continue

        # Per-row SAVEPOINT: leaving the ``async with`` through an
        # exception rolls back only this row's writes.
        try:
            async with session.begin_nested():
                outcome = await revert_action(
                    session,
                    action=entry,
                    requester_user_id=requester_user_id,
                )
                if outcome.status != "ok":
                    raise _OutcomeRollbackError(outcome)
        except _OutcomeRollbackError as rollback:
            rejected = rollback.outcome
            _note(entry, _classify_outcome(rejected), message=rejected.message)
        except IntegrityError:
            # The partial unique index caught a concurrent revert that won
            # the race after the batched pre-flight probe above. Look up
            # the winner so its id can be surfaced to the client.
            winner_id = await _was_already_reverted(session, action_id=entry.id)
            _note(entry, "already_reverted", new_action_id=winner_id)
        except Exception as err:  # pragma: no cover — defensive, logged
            logger.exception(
                "Unexpected revert failure inside batch for action_id=%s",
                entry.id,
            )
            _note(entry, "failed", error=str(err) or err.__class__.__name__)
        else:
            _note(
                entry,
                "reverted",
                message=outcome.message,
                new_action_id=outcome.new_action_id,
            )

    # Single commit at the end: successful SAVEPOINTs were already
    # released and failed ones rolled back to their savepoint.
    try:
        await session.commit()
    except Exception as err:  # pragma: no cover — defensive
        logger.exception(
            "Final commit for revert-turn failed (thread=%s turn=%s)",
            thread_id,
            chat_turn_id,
        )
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail="Internal error while finalising revert-turn batch.",
        ) from err

    needs_attention = any(
        tally[bucket] > 0
        for bucket in ("failed", "not_reversible", "permission_denied")
    )
    overall_status: Literal["ok", "partial"] = "partial" if needs_attention else "ok"

    return RevertTurnResponse(
        status=overall_status,
        chat_turn_id=chat_turn_id,
        total=len(turn_actions),
        results=per_action,
        **tally,
    )
|
||||
|
||||
|
||||
class _OutcomeRollbackError(Exception):
|
||||
"""Sentinel raised inside the SAVEPOINT to roll back a non-OK outcome.
|
||||
|
||||
``revert_action`` writes a new ``agent_action_log`` row only on the
|
||||
happy path, but on the failure paths it sometimes mutates the
|
||||
``DocumentRevision``/``Document`` tables before deciding the action
|
||||
is not reversible. Wrapping each call in ``begin_nested`` and raising
|
||||
this from the failure branch ensures we always discard partial
|
||||
writes for failed rows.
|
||||
"""
|
||||
|
||||
def __init__(self, outcome: RevertOutcome) -> None:
|
||||
self.outcome = outcome
|
||||
super().__init__(outcome.message)
|
||||
|
||||
|
||||
__all__ = ["router"]
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
|
@ -136,6 +137,260 @@ def _resolve_filesystem_selection(
|
|||
)
|
||||
|
||||
|
||||
def _find_pre_turn_checkpoint_id(
|
||||
checkpoint_tuples: list,
|
||||
*,
|
||||
turn_id: str,
|
||||
) -> str | None:
|
||||
"""Locate the LangGraph checkpoint immediately before ``turn_id`` started.
|
||||
|
||||
``checkpoint_tuples`` arrives newest-first from
|
||||
``checkpointer.alist(config)``. We walk OLDEST-first (``reversed``)
|
||||
and remember the most recent checkpoint that does NOT belong to the
|
||||
edited turn. As soon as we cross into the edited turn (a checkpoint
|
||||
whose ``turn_id`` matches), we return the previously-tracked
|
||||
checkpoint — that's the state immediately before ``turn_id`` began.
|
||||
|
||||
The naive "newest-first, return first non-matching" approach is
|
||||
INCORRECT when later turns exist after ``turn_id``: their
|
||||
checkpoints also satisfy ``cp_turn_id != turn_id`` and would be
|
||||
returned before the real pre-turn boundary is reached.
|
||||
|
||||
Reads from ``cp_tuple.metadata`` (the durable surface promoted from
|
||||
``configurable`` at write time) rather than ``config["configurable"]``
|
||||
so the lookup is portable across checkpointer implementations.
|
||||
|
||||
Returns ``None`` when no eligible pre-turn checkpoint exists (e.g.
|
||||
the edited turn is the very first turn of the thread). Callers fall
|
||||
back to the oldest available checkpoint in that case.
|
||||
"""
|
||||
|
||||
last_pre_turn_target: str | None = None
|
||||
for cp_tuple in reversed(checkpoint_tuples): # oldest -> newest
|
||||
metadata = getattr(cp_tuple, "metadata", None) or {}
|
||||
cp_turn_id = metadata.get("turn_id") if isinstance(metadata, dict) else None
|
||||
if cp_turn_id == turn_id:
|
||||
# Crossed into the edited turn; the previous tracked
|
||||
# checkpoint is the rewind target. May be ``None`` if we hit
|
||||
# the edited turn on the very first iteration.
|
||||
return last_pre_turn_target
|
||||
try:
|
||||
last_pre_turn_target = cp_tuple.config["configurable"]["checkpoint_id"]
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
return last_pre_turn_target
|
||||
|
||||
|
||||
async def _revert_turns_for_regenerate(
|
||||
*,
|
||||
thread_id: int,
|
||||
chat_turn_ids: list[str],
|
||||
requester_user_id: str,
|
||||
) -> dict:
|
||||
"""Best-effort revert pass for every ``chat_turn_id`` in ``chat_turn_ids``.
|
||||
|
||||
Runs BEFORE the regenerate stream so the frontend can surface
|
||||
partial-rollback feedback alongside the new assistant turn. Each
|
||||
turn's actions are reverted in their own SAVEPOINTs (handled
|
||||
inside :mod:`app.routes.agent_revert_route`'s helpers) so a single
|
||||
failure never poisons the batch.
|
||||
|
||||
Sequencing inside the request: revert THEN regenerate. The
|
||||
operation is NOT atomic and partial state IS surfaced — see the
|
||||
plan's "Sequencing inside the request" note.
|
||||
"""
|
||||
|
||||
from app.routes.agent_revert_route import (
|
||||
RevertTurnActionResult,
|
||||
_classify_outcome,
|
||||
_OutcomeRollbackError,
|
||||
_was_already_reverted,
|
||||
_was_already_reverted_batch,
|
||||
)
|
||||
from app.services.revert_service import (
|
||||
can_revert,
|
||||
revert_action,
|
||||
)
|
||||
|
||||
aggregated_results: list[dict] = []
|
||||
# Exhaustive counters keep the response invariant
|
||||
# ``total == sum(counters)`` true for ``data-revert-results``.
|
||||
counts = {
|
||||
"reverted": 0,
|
||||
"already_reverted": 0,
|
||||
"not_reversible": 0,
|
||||
"permission_denied": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
}
|
||||
|
||||
# Local import keeps the route module's existing imports tidy and
|
||||
# avoids a circular dependency at module-load time.
|
||||
from app.db import AgentActionLog as _AgentActionLog
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
for chat_turn_id in chat_turn_ids:
|
||||
rows_stmt = (
|
||||
select(_AgentActionLog)
|
||||
.where(
|
||||
_AgentActionLog.thread_id == thread_id,
|
||||
_AgentActionLog.chat_turn_id == chat_turn_id,
|
||||
)
|
||||
.order_by(
|
||||
_AgentActionLog.created_at.desc(),
|
||||
_AgentActionLog.id.desc(),
|
||||
)
|
||||
)
|
||||
rows = (await session.execute(rows_stmt)).scalars().all()
|
||||
|
||||
# Batch idempotency probe across the turn (single SELECT
|
||||
# instead of one per row).
|
||||
eligible_ids = [r.id for r in rows if r.reverse_of is None]
|
||||
already_reverted_map = await _was_already_reverted_batch(
|
||||
session, action_ids=eligible_ids
|
||||
)
|
||||
|
||||
for action in rows:
|
||||
if action.reverse_of is not None:
|
||||
counts["skipped"] += 1
|
||||
aggregated_results.append(
|
||||
RevertTurnActionResult(
|
||||
action_id=action.id,
|
||||
tool_name=action.tool_name,
|
||||
status="skipped",
|
||||
message="Row is itself a revert action; skipped.",
|
||||
).model_dump()
|
||||
)
|
||||
continue
|
||||
|
||||
existing_revert_id = already_reverted_map.get(action.id)
|
||||
if existing_revert_id is not None:
|
||||
counts["already_reverted"] += 1
|
||||
aggregated_results.append(
|
||||
RevertTurnActionResult(
|
||||
action_id=action.id,
|
||||
tool_name=action.tool_name,
|
||||
status="already_reverted",
|
||||
new_action_id=existing_revert_id,
|
||||
).model_dump()
|
||||
)
|
||||
continue
|
||||
|
||||
if not can_revert(
|
||||
requester_user_id=requester_user_id,
|
||||
action=action,
|
||||
is_admin=False,
|
||||
):
|
||||
counts["permission_denied"] += 1
|
||||
aggregated_results.append(
|
||||
RevertTurnActionResult(
|
||||
action_id=action.id,
|
||||
tool_name=action.tool_name,
|
||||
status="permission_denied",
|
||||
message="You are not allowed to revert this action.",
|
||||
).model_dump()
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
outcome = await revert_action(
|
||||
session,
|
||||
action=action,
|
||||
requester_user_id=requester_user_id,
|
||||
)
|
||||
if outcome.status != "ok":
|
||||
raise _OutcomeRollbackError(outcome)
|
||||
except _OutcomeRollbackError as rollback:
|
||||
outcome = rollback.outcome
|
||||
classified = _classify_outcome(outcome)
|
||||
if classified == "permission_denied":
|
||||
counts["permission_denied"] += 1
|
||||
else:
|
||||
counts["not_reversible"] += 1
|
||||
aggregated_results.append(
|
||||
RevertTurnActionResult(
|
||||
action_id=action.id,
|
||||
tool_name=action.tool_name,
|
||||
status=classified,
|
||||
message=outcome.message,
|
||||
).model_dump()
|
||||
)
|
||||
continue
|
||||
except IntegrityError:
|
||||
# Concurrent revert won the race against the
|
||||
# pre-flight ``_was_already_reverted`` SELECT.
|
||||
# Surface the winning revert id so the client can
|
||||
# treat this as a successful idempotent op.
|
||||
existing_revert_id = await _was_already_reverted(
|
||||
session, action_id=action.id
|
||||
)
|
||||
counts["already_reverted"] += 1
|
||||
aggregated_results.append(
|
||||
RevertTurnActionResult(
|
||||
action_id=action.id,
|
||||
tool_name=action.tool_name,
|
||||
status="already_reverted",
|
||||
new_action_id=existing_revert_id,
|
||||
).model_dump()
|
||||
)
|
||||
continue
|
||||
except Exception as err: # pragma: no cover — defensive
|
||||
_logger.exception(
|
||||
"Unexpected revert failure during regenerate batch "
|
||||
"for action_id=%s",
|
||||
action.id,
|
||||
)
|
||||
counts["failed"] += 1
|
||||
aggregated_results.append(
|
||||
RevertTurnActionResult(
|
||||
action_id=action.id,
|
||||
tool_name=action.tool_name,
|
||||
status="failed",
|
||||
error=str(err) or err.__class__.__name__,
|
||||
).model_dump()
|
||||
)
|
||||
continue
|
||||
|
||||
counts["reverted"] += 1
|
||||
aggregated_results.append(
|
||||
RevertTurnActionResult(
|
||||
action_id=action.id,
|
||||
tool_name=action.tool_name,
|
||||
status="reverted",
|
||||
message=outcome.message,
|
||||
new_action_id=outcome.new_action_id,
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception:
|
||||
_logger.exception(
|
||||
"[regenerate-revert] Final commit failed; rolling back batch."
|
||||
)
|
||||
await session.rollback()
|
||||
|
||||
has_partial = (
|
||||
counts["failed"] > 0
|
||||
or counts["not_reversible"] > 0
|
||||
or counts["permission_denied"] > 0
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "partial" if has_partial else "ok",
|
||||
"chat_turn_ids": chat_turn_ids,
|
||||
"total": len(aggregated_results),
|
||||
"reverted": counts["reverted"],
|
||||
"already_reverted": counts["already_reverted"],
|
||||
"not_reversible": counts["not_reversible"],
|
||||
"permission_denied": counts["permission_denied"],
|
||||
"failed": counts["failed"],
|
||||
"skipped": counts["skipped"],
|
||||
"results": aggregated_results,
|
||||
}
|
||||
|
||||
|
||||
def _try_delete_sandbox(thread_id: int) -> None:
|
||||
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
||||
from app.agents.new_chat.sandbox import (
|
||||
|
|
@ -574,6 +829,7 @@ async def get_thread_messages(
|
|||
token_usage=TokenUsageSummary.model_validate(msg.token_usage)
|
||||
if msg.token_usage
|
||||
else None,
|
||||
turn_id=msg.turn_id,
|
||||
)
|
||||
for msg in db_messages
|
||||
]
|
||||
|
|
@ -1006,12 +1262,24 @@ async def append_message(
|
|||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
|
||||
# Create message
|
||||
# Create message. ``turn_id`` is the per-turn correlation id from
|
||||
# ``configurable.turn_id`` (added in migration 136) — when the
|
||||
# client streams it back to ``appendMessage``, we persist it so
|
||||
# C1's edit-from-arbitrary-position can later map this message
|
||||
# back to the LangGraph checkpoint that produced its turn.
|
||||
raw_turn_id = raw_body.get("turn_id")
|
||||
turn_id_value = (
|
||||
str(raw_turn_id).strip()
|
||||
if isinstance(raw_turn_id, str) and raw_turn_id.strip()
|
||||
else None
|
||||
)
|
||||
|
||||
db_message = NewChatMessage(
|
||||
thread_id=thread_id,
|
||||
role=message_role,
|
||||
content=content,
|
||||
author_id=user.id,
|
||||
turn_id=turn_id_value,
|
||||
)
|
||||
session.add(db_message)
|
||||
|
||||
|
|
@ -1050,6 +1318,7 @@ async def append_message(
|
|||
created_at=db_message.created_at,
|
||||
author_id=db_message.author_id,
|
||||
token_usage=None,
|
||||
turn_id=db_message.turn_id,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -1373,43 +1642,123 @@ async def regenerate_response(
|
|||
user_query_to_use = request.user_query
|
||||
regenerate_image_urls: list[str] = []
|
||||
|
||||
# Look through checkpoints to find the right one
|
||||
# We want to find the checkpoint just before the last HumanMessage
|
||||
for i, cp_tuple in enumerate(checkpoint_tuples):
|
||||
# Access the checkpoint's channel_values which contains "messages"
|
||||
checkpoint_data = cp_tuple.checkpoint
|
||||
channel_values = checkpoint_data.get("channel_values", {})
|
||||
state_messages = channel_values.get("messages", [])
|
||||
# ---------------------------------------------------------------
|
||||
# Edit-from-arbitrary-position. When the client passes
|
||||
# ``from_message_id`` we look up its persisted ``turn_id`` (added
|
||||
# in migration 136) and pick the checkpoint immediately before
|
||||
# that turn started.
|
||||
#
|
||||
# Legacy graceful-degradation contract:
|
||||
# * Rows persisted BEFORE migration 136 have ``turn_id IS NULL``.
|
||||
# Returning 400 in that case is the wrong UX — the user is
|
||||
# editing an old message in an existing thread and just wants
|
||||
# it to work. We instead skip the checkpoint rewind (the
|
||||
# stream falls back to the latest state) and skip the revert
|
||||
# pass (no chat_turn_id available to walk). Deletion still
|
||||
# uses ``created_at``, so the messages-after-cursor slice is
|
||||
# correct on both legacy and post-136 rows.
|
||||
# ---------------------------------------------------------------
|
||||
from_message_turn_id: str | None = None
|
||||
from_message_created_at: datetime | None = None
|
||||
legacy_from_message: bool = False
|
||||
if request.from_message_id is not None:
|
||||
from_msg_row = await session.execute(
|
||||
select(NewChatMessage).filter(
|
||||
NewChatMessage.id == request.from_message_id,
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
)
|
||||
)
|
||||
from_msg = from_msg_row.scalars().first()
|
||||
if from_msg is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="from_message_id not found in this thread.",
|
||||
)
|
||||
from_message_created_at = from_msg.created_at
|
||||
if not from_msg.turn_id:
|
||||
# Legacy row — surface the degradation in logs but let
|
||||
# the request proceed with the slice-based delete and a
|
||||
# cold-start checkpoint.
|
||||
legacy_from_message = True
|
||||
_logger.warning(
|
||||
"[regenerate] from_message_id=%s on thread=%s has no "
|
||||
"turn_id (legacy row pre-migration-136). Falling back "
|
||||
"to slice-based delete without checkpoint rewind. "
|
||||
"revert_actions=%s will be ignored.",
|
||||
request.from_message_id,
|
||||
thread_id,
|
||||
request.revert_actions,
|
||||
)
|
||||
else:
|
||||
from_message_turn_id = from_msg.turn_id
|
||||
|
||||
if state_messages:
|
||||
last_msg = state_messages[-1]
|
||||
# Find a checkpoint where the last message is NOT a HumanMessage
|
||||
# This means we're at a state before the user's last message
|
||||
if not isinstance(last_msg, HumanMessage):
|
||||
# If no new user_query provided (reload), extract from a later checkpoint
|
||||
if user_query_to_use is None and i > 0:
|
||||
# Get the user query from a more recent checkpoint
|
||||
for prev_cp_tuple in checkpoint_tuples[:i]:
|
||||
prev_checkpoint_data = prev_cp_tuple.checkpoint
|
||||
prev_channel_values = prev_checkpoint_data.get(
|
||||
"channel_values", {}
|
||||
)
|
||||
prev_messages = prev_channel_values.get("messages", [])
|
||||
for msg in reversed(prev_messages):
|
||||
if isinstance(msg, HumanMessage):
|
||||
q, imgs = split_langchain_human_content(msg.content)
|
||||
user_query_to_use = q
|
||||
regenerate_image_urls = imgs
|
||||
break
|
||||
if user_query_to_use is not None and (
|
||||
str(user_query_to_use).strip() or regenerate_image_urls
|
||||
):
|
||||
break
|
||||
|
||||
target_checkpoint_id = cp_tuple.config["configurable"][
|
||||
# Walk oldest-to-newest and pick the LAST checkpoint whose
|
||||
# ``turn_id`` differs from the edited turn — that's the state
|
||||
# immediately before this turn started running. We read from
|
||||
# ``metadata`` (the durable surface) rather than
|
||||
# ``config["configurable"]`` so the lookup works across
|
||||
# checkpointer implementations.
|
||||
target_checkpoint_id = _find_pre_turn_checkpoint_id(
|
||||
checkpoint_tuples,
|
||||
turn_id=from_message_turn_id,
|
||||
)
|
||||
if target_checkpoint_id is None and len(checkpoint_tuples) > 0:
|
||||
# Fall back to the oldest checkpoint — better than
|
||||
# 400ing when the agent didn't checkpoint pre-turn
|
||||
# (e.g. very first turn of the thread).
|
||||
target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
|
||||
"checkpoint_id"
|
||||
]
|
||||
break
|
||||
|
||||
# Look through checkpoints to find the right one
|
||||
# We want to find the checkpoint just before the last HumanMessage.
|
||||
# We enter this branch when:
|
||||
# * the client did NOT pin ``from_message_id`` (legacy reload/edit), OR
|
||||
# * the client pinned ``from_message_id`` but the row is a
|
||||
# legacy pre-migration-136 row with no ``turn_id`` (we
|
||||
# downgraded to the same heuristic as a regular reload).
|
||||
# We DO skip it when a real turn_id pinned ``target_checkpoint_id``
|
||||
# — that's the C1 happy path and the heuristic below would just
|
||||
# re-derive a worse target.
|
||||
if request.from_message_id is None or legacy_from_message:
|
||||
for i, cp_tuple in enumerate(checkpoint_tuples):
|
||||
# Access the checkpoint's channel_values which contains "messages"
|
||||
checkpoint_data = cp_tuple.checkpoint
|
||||
channel_values = checkpoint_data.get("channel_values", {})
|
||||
state_messages = channel_values.get("messages", [])
|
||||
|
||||
if state_messages:
|
||||
last_msg = state_messages[-1]
|
||||
# Find a checkpoint where the last message is NOT a HumanMessage
|
||||
# This means we're at a state before the user's last message
|
||||
if not isinstance(last_msg, HumanMessage):
|
||||
# If no new user_query provided (reload), extract from a later checkpoint
|
||||
if user_query_to_use is None and i > 0:
|
||||
# Get the user query from a more recent checkpoint
|
||||
for prev_cp_tuple in checkpoint_tuples[:i]:
|
||||
prev_checkpoint_data = prev_cp_tuple.checkpoint
|
||||
prev_channel_values = prev_checkpoint_data.get(
|
||||
"channel_values", {}
|
||||
)
|
||||
prev_messages = prev_channel_values.get("messages", [])
|
||||
for msg in reversed(prev_messages):
|
||||
if isinstance(msg, HumanMessage):
|
||||
q, imgs = split_langchain_human_content(
|
||||
msg.content
|
||||
)
|
||||
user_query_to_use = q
|
||||
regenerate_image_urls = imgs
|
||||
break
|
||||
if user_query_to_use is not None and (
|
||||
str(user_query_to_use).strip()
|
||||
or regenerate_image_urls
|
||||
):
|
||||
break
|
||||
|
||||
target_checkpoint_id = cp_tuple.config["configurable"][
|
||||
"checkpoint_id"
|
||||
]
|
||||
break
|
||||
|
||||
# If we couldn't find a good checkpoint, try alternative approaches
|
||||
if target_checkpoint_id is None and checkpoint_tuples:
|
||||
|
|
@ -1472,18 +1821,51 @@ async def regenerate_response(
|
|||
detail="Could not determine user query for regeneration. Please provide a user_query.",
|
||||
)
|
||||
|
||||
# Get the last two messages to delete AFTER streaming succeeds
|
||||
# This prevents data loss if streaming fails
|
||||
last_messages_result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.filter(NewChatMessage.thread_id == thread_id)
|
||||
.order_by(NewChatMessage.created_at.desc())
|
||||
.limit(2)
|
||||
)
|
||||
# Get the messages to delete AFTER streaming succeeds.
|
||||
# This prevents data loss if streaming fails.
|
||||
#
|
||||
# When ``from_message_id`` is set we slice from that message
|
||||
# forward (using ``created_at`` so we also catch any tool/system
|
||||
# messages persisted into the same turn). Otherwise
|
||||
# we keep the legacy "last 2 messages" rewind.
|
||||
if request.from_message_id is not None and from_message_created_at is not None:
|
||||
last_messages_result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.filter(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.created_at >= from_message_created_at,
|
||||
)
|
||||
.order_by(NewChatMessage.created_at.desc())
|
||||
)
|
||||
else:
|
||||
last_messages_result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.filter(NewChatMessage.thread_id == thread_id)
|
||||
.order_by(NewChatMessage.created_at.desc())
|
||||
.limit(2)
|
||||
)
|
||||
messages_to_delete = list(last_messages_result.scalars().all())
|
||||
|
||||
message_ids_to_delete = [msg.id for msg in messages_to_delete]
|
||||
|
||||
# When revert_actions is requested, collect the set of
|
||||
# ``chat_turn_id``s present in the slice we're about to delete.
|
||||
# Each one will be reverted (best-effort) BEFORE the regenerate
|
||||
# stream begins. Legacy rows have ``turn_id=None`` and silently
|
||||
# contribute nothing — we already logged the degradation above.
|
||||
revert_turn_ids: list[str] = []
|
||||
if (
|
||||
request.revert_actions
|
||||
and request.from_message_id is not None
|
||||
and not legacy_from_message
|
||||
):
|
||||
seen_turns: set[str] = set()
|
||||
for msg in messages_to_delete:
|
||||
tid = msg.turn_id
|
||||
if tid and tid not in seen_turns:
|
||||
seen_turns.add(tid)
|
||||
revert_turn_ids.append(tid)
|
||||
|
||||
# Get search space for LLM config
|
||||
search_space_result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
||||
|
|
@ -1507,6 +1889,24 @@ async def regenerate_response(
|
|||
# This prevents data loss if streaming fails (network error, LLM error, etc.)
|
||||
async def stream_with_cleanup():
|
||||
streaming_completed = False
|
||||
# Best-effort revert pass BEFORE the regenerate stream begins.
|
||||
# Each turn is reverted independently (per-row SAVEPOINTs
|
||||
# inside the route helper) and the per-action results are surfaced
|
||||
# on a single ``data-revert-results`` SSE event so the frontend
|
||||
# can render any failed rows alongside the new turn. Failures here
|
||||
# do NOT abort the regeneration — partial rollback is documented
|
||||
# behaviour.
|
||||
if revert_turn_ids:
|
||||
revert_results = await _revert_turns_for_regenerate(
|
||||
thread_id=thread_id,
|
||||
chat_turn_ids=revert_turn_ids,
|
||||
requester_user_id=str(user.id),
|
||||
)
|
||||
envelope = {
|
||||
"type": "data-revert-results",
|
||||
"data": revert_results,
|
||||
}
|
||||
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
|
||||
try:
|
||||
async for chunk in stream_new_chat(
|
||||
user_query=str(user_query_to_use),
|
||||
|
|
|
|||
|
|
@ -51,6 +51,11 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
|
|||
author_display_name: str | None = None
|
||||
author_avatar_url: str | None = None
|
||||
token_usage: TokenUsageSummary | None = None
|
||||
# Per-turn correlation id (``f"{chat_id}:{ms}"``) from
|
||||
# ``configurable.turn_id`` at streaming time. Nullable because
|
||||
# legacy rows predate the column; clients should treat NULL as
|
||||
# "edit-from-this-message is unavailable".
|
||||
turn_id: str | None = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
|
|
@ -241,6 +246,15 @@ class RegenerateRequest(BaseModel):
|
|||
|
||||
For edit, optional user_images (when not None) replaces image URLs resolved from
|
||||
checkpoint/DB so the client can send the full user turn (text and/or images).
|
||||
|
||||
Edit-from-arbitrary-position. When ``from_message_id`` is provided
|
||||
the route slices conversation history starting at that message (instead of
|
||||
the legacy "last 2 messages" rewind), rewinds the LangGraph checkpoint by
|
||||
matching ``configurable.turn_id`` stored on the message (added in migration 136), and
|
||||
optionally reverts every reversible action emitted in turns at or after
|
||||
``from_message_id``. The revert step is best-effort and runs BEFORE the
|
||||
regenerate stream — partial failures are surfaced via SSE
|
||||
``data-revert-results`` and do not abort the regeneration.
|
||||
"""
|
||||
|
||||
search_space_id: int
|
||||
|
|
@ -257,6 +271,28 @@ class RegenerateRequest(BaseModel):
|
|||
default=None,
|
||||
description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB",
|
||||
)
|
||||
from_message_id: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Message id to rewind to. When set, history is sliced "
|
||||
"from this message forward and the LangGraph checkpoint is "
|
||||
"rewound to the state immediately preceding this turn. Legacy "
|
||||
"rows that predate migration 136 have ``turn_id=None`` and "
|
||||
"still process — the route logs a warning, skips the "
|
||||
"checkpoint rewind, and ignores ``revert_actions`` (no "
|
||||
"chat_turn_id available to walk)."
|
||||
),
|
||||
)
|
||||
revert_actions: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"When true, every reversible action emitted at or "
|
||||
"after ``from_message_id`` is reverted before the regenerate "
|
||||
"stream begins. Per-action results are surfaced via the "
|
||||
"``data-revert-results`` SSE event. Partial failures DO NOT "
|
||||
"abort the regeneration."
|
||||
),
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_regenerate_user_images(self) -> Self:
|
||||
|
|
@ -264,6 +300,14 @@ class RegenerateRequest(BaseModel):
|
|||
raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
def _validate_revert_actions_requires_from_message(self) -> Self:
    """Reject requests that ask for a revert without naming a rewind target.

    ``revert_actions`` is only accepted together with ``from_message_id``.

    Returns:
        The validated model instance, unchanged.

    Raises:
        ValueError: ``revert_actions`` is truthy while ``from_message_id``
            is ``None``.
    """
    # Nothing to check when no revert was requested.
    if not self.revert_actions:
        return self
    # A revert was requested and a rewind target is present: valid.
    if self.from_message_id is not None:
        return self
    raise ValueError(
        "revert_actions requires from_message_id; specify which message to rewind to"
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Agent Tools Schemas
|
||||
|
|
|
|||
|
|
@ -584,13 +584,33 @@ class VercelStreamingService:
|
|||
# Tool Parts
|
||||
# =========================================================================
|
||||
|
||||
def format_tool_input_start(self, tool_call_id: str, tool_name: str) -> str:
|
||||
def format_tool_input_start(
|
||||
self,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
*,
|
||||
langchain_tool_call_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format the start of tool input streaming.
|
||||
|
||||
Args:
|
||||
tool_call_id: The unique tool call identifier
|
||||
tool_name: The name of the tool being called
|
||||
tool_call_id: The unique tool call identifier. May be EITHER the
|
||||
synthetic ``call_<run_id>`` id derived from LangGraph
|
||||
``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
|
||||
OFF, or the unmatched-fallback path under parity_v2) OR
|
||||
the authoritative LangChain ``tool_call.id`` (parity_v2
|
||||
path: when the provider streams ``tool_call_chunks`` we
|
||||
register the ``index`` and reuse the lc-id as the card
|
||||
id so live ``tool-input-delta`` events can be routed
|
||||
without a downstream join). Either way, the same id is
|
||||
preserved across ``tool-input-start`` / ``-delta`` /
|
||||
``-available`` / ``tool-output-available`` for one call.
|
||||
tool_name: The name of the tool being called.
|
||||
langchain_tool_call_id: Optional authoritative LangChain
|
||||
``tool_call.id``. When set, surfaces as
|
||||
``langchainToolCallId`` so the frontend can join this card
|
||||
to the action-log row written by ``ActionLogMiddleware``.
|
||||
|
||||
Returns:
|
||||
str: SSE formatted tool input start part
|
||||
|
|
@ -598,13 +618,14 @@ class VercelStreamingService:
|
|||
Example output:
|
||||
data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"}
|
||||
"""
|
||||
return self._format_sse(
|
||||
{
|
||||
"type": "tool-input-start",
|
||||
"toolCallId": tool_call_id,
|
||||
"toolName": tool_name,
|
||||
}
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"type": "tool-input-start",
|
||||
"toolCallId": tool_call_id,
|
||||
"toolName": tool_name,
|
||||
}
|
||||
if langchain_tool_call_id:
|
||||
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||
return self._format_sse(payload)
|
||||
|
||||
def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:
|
||||
"""
|
||||
|
|
@ -629,7 +650,12 @@ class VercelStreamingService:
|
|||
)
|
||||
|
||||
def format_tool_input_available(
|
||||
self, tool_call_id: str, tool_name: str, input_data: dict[str, Any]
|
||||
self,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
input_data: dict[str, Any],
|
||||
*,
|
||||
langchain_tool_call_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format the completion of tool input.
|
||||
|
|
@ -638,6 +664,8 @@ class VercelStreamingService:
|
|||
tool_call_id: The tool call identifier
|
||||
tool_name: The name of the tool
|
||||
input_data: The complete tool input parameters
|
||||
langchain_tool_call_id: Optional authoritative LangChain
|
||||
``tool_call.id`` (see ``format_tool_input_start``).
|
||||
|
||||
Returns:
|
||||
str: SSE formatted tool input available part
|
||||
|
|
@ -645,22 +673,34 @@ class VercelStreamingService:
|
|||
Example output:
|
||||
data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}}
|
||||
"""
|
||||
return self._format_sse(
|
||||
{
|
||||
"type": "tool-input-available",
|
||||
"toolCallId": tool_call_id,
|
||||
"toolName": tool_name,
|
||||
"input": input_data,
|
||||
}
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"type": "tool-input-available",
|
||||
"toolCallId": tool_call_id,
|
||||
"toolName": tool_name,
|
||||
"input": input_data,
|
||||
}
|
||||
if langchain_tool_call_id:
|
||||
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||
return self._format_sse(payload)
|
||||
|
||||
def format_tool_output_available(self, tool_call_id: str, output: Any) -> str:
|
||||
def format_tool_output_available(
|
||||
self,
|
||||
tool_call_id: str,
|
||||
output: Any,
|
||||
*,
|
||||
langchain_tool_call_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format tool execution output.
|
||||
|
||||
Args:
|
||||
tool_call_id: The tool call identifier
|
||||
output: The tool execution result
|
||||
langchain_tool_call_id: Optional authoritative LangChain
|
||||
``tool_call.id`` extracted from ``ToolMessage.tool_call_id``.
|
||||
When set, the frontend can backfill any card whose
|
||||
``langchainToolCallId`` was not yet known at
|
||||
``tool-input-start`` time.
|
||||
|
||||
Returns:
|
||||
str: SSE formatted tool output available part
|
||||
|
|
@ -668,13 +708,14 @@ class VercelStreamingService:
|
|||
Example output:
|
||||
data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}}
|
||||
"""
|
||||
return self._format_sse(
|
||||
{
|
||||
"type": "tool-output-available",
|
||||
"toolCallId": tool_call_id,
|
||||
"output": output,
|
||||
}
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"type": "tool-output-available",
|
||||
"toolCallId": tool_call_id,
|
||||
"output": output,
|
||||
}
|
||||
if langchain_tool_call_id:
|
||||
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||
return self._format_sse(payload)
|
||||
|
||||
# =========================================================================
|
||||
# Step Parts
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@ Operation outcomes mirror the plan:
|
|||
|
||||
* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
|
||||
:class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
|
||||
written before the original mutation.
|
||||
written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh
|
||||
row from the snapshot; ``write_file`` create / ``mkdir`` DELETE the row
|
||||
that was created; everything else is an in-place restore.
|
||||
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
|
||||
the inverse tool through the agent's normal permission stack (NOT
|
||||
bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``.
|
||||
|
|
@ -18,6 +20,11 @@ Operation outcomes mirror the plan:
|
|||
A successful revert appends a NEW row to ``agent_action_log`` with
|
||||
``reverse_of=<original_action_id>`` and the requesting user's
|
||||
``user_id``, preserving an auditable chain.
|
||||
|
||||
Dispatch must be exact-match (``tool_name == name``), NOT prefix matching.
|
||||
``"rmdir".startswith("rm")`` would otherwise mis-route directory revert
|
||||
to the document branch (and ``delete_note`` vs ``delete_folder`` is the
|
||||
same trap waiting to happen).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -25,17 +32,31 @@ from __future__ import annotations
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
safe_filename,
|
||||
safe_folder_segment,
|
||||
)
|
||||
from app.db import (
|
||||
AgentActionLog,
|
||||
Chunk,
|
||||
Document,
|
||||
DocumentRevision,
|
||||
DocumentType,
|
||||
Folder,
|
||||
FolderRevision,
|
||||
NewChatThread,
|
||||
)
|
||||
from app.utils.document_converters import (
|
||||
embed_texts,
|
||||
generate_content_hash,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -110,14 +131,244 @@ def can_revert(
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Revert paths
|
||||
# Helper: reconstruct virtual path from a snapshot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _virtual_path_from_snapshot(
    session: AsyncSession,
    revision: DocumentRevision,
) -> str | None:
    """Work out the virtual path a document had before it was mutated.

    Two sources are tried, in order:

    1. ``revision.metadata_before["virtual_path"]`` — used as-is when it is
       a string starting with ``DOCUMENTS_ROOT``.
    2. A path composed from ``revision.folder_id_before`` and
       ``revision.title_before``: the folder chain is followed through
       ``parent_id`` and joined under ``DOCUMENTS_ROOT``, then the
       sanitised title is appended as the filename.

    Args:
        session: Session used to load each :class:`Folder` in the chain.
        revision: Snapshot row describing the pre-mutation state.

    Returns:
        The reconstructed path, or ``None`` when the snapshot has no usable
        title or a folder in the chain cannot be loaded.
    """
    snapshot_meta = revision.metadata_before or {}
    if isinstance(snapshot_meta, dict):
        stored_path = snapshot_meta.get("virtual_path")
        if isinstance(stored_path, str) and stored_path.startswith(DOCUMENTS_ROOT):
            return stored_path

    title = revision.title_before
    if not (isinstance(title, str) and title):
        return None

    # Collect folder segments leaf-first by following ``parent_id``.
    # ``seen`` stops the walk if the chain ever revisits a folder id.
    segments: list[str] = []
    seen: set[int] = set()
    folder_id: int | None = revision.folder_id_before
    while folder_id is not None:
        if folder_id in seen:
            break
        seen.add(folder_id)
        folder = await session.get(Folder, folder_id)
        if folder is None:
            return None
        segments.append(safe_folder_segment(str(folder.name or "")))
        folder_id = folder.parent_id

    # Segments were gathered leaf-first; the path needs them root-first.
    if segments:
        directory = f"{DOCUMENTS_ROOT}/" + "/".join(reversed(segments))
    else:
        directory = DOCUMENTS_ROOT
    return f"{directory}/{safe_filename(title)}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document revision restore (write/edit/move/rm)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _set_field(target: Any, field: str, value: Any) -> None:
|
||||
if value is not None:
|
||||
setattr(target, field, value)
|
||||
|
||||
|
||||
async def _restore_in_place_document(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
revision: DocumentRevision,
|
||||
) -> RevertOutcome:
|
||||
"""Apply an in-place restore to an existing :class:`Document`."""
|
||||
# Returns a RevertOutcome with status "tool_unavailable" when there is no
# live Document row to restore into, otherwise status "ok" after the
# snapshot values have been written back onto the row.
# The revision carries no document id, so there is nothing to load.
if revision.document_id is None:
|
||||
return RevertOutcome(
|
||||
status="tool_unavailable",
|
||||
message=(
|
||||
"Original document was hard-deleted; in-place restore is not possible."
|
||||
),
|
||||
)
|
||||
doc = await session.get(Document, revision.document_id)
|
||||
# The id is set but the row no longer exists.
if doc is None:
|
||||
return RevertOutcome(
|
||||
status="tool_unavailable",
|
||||
message="Original document has been deleted; revert cannot proceed.",
|
||||
)
|
||||
|
||||
# _set_field ignores None, so a snapshot column that is None leaves the
# document's current value in place. content_before feeds both ``content``
# and ``source_markdown``.
_set_field(doc, "content", revision.content_before)
|
||||
_set_field(doc, "source_markdown", revision.content_before)
|
||||
_set_field(doc, "title", revision.title_before)
|
||||
_set_field(doc, "folder_id", revision.folder_id_before)
|
||||
# Metadata is replaced (with a copy) only when the snapshot holds a
# non-empty dict.
metadata_before = revision.metadata_before or {}
|
||||
if isinstance(metadata_before, dict) and metadata_before:
|
||||
doc.document_metadata = dict(metadata_before)
|
||||
|
||||
# Recompute the content hash from the restored text.
if isinstance(revision.content_before, str):
|
||||
doc.content_hash = generate_content_hash(
|
||||
revision.content_before, doc.search_space_id
|
||||
)
|
||||
|
||||
# Recompute the unique identifier hash from the pre-mutation virtual path,
# when one can be reconstructed. The hash is always built with
# DocumentType.NOTE.
virtual_path = await _virtual_path_from_snapshot(session, revision)
|
||||
if virtual_path:
|
||||
doc.unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.NOTE,
|
||||
virtual_path,
|
||||
doc.search_space_id,
|
||||
)
|
||||
|
||||
# When the snapshot has a chunk list, delete the document's current chunks
# and rebuild them from the snapshot entries that are dicts with a string
# "content", embedding those texts again.
chunks_before = revision.chunks_before
|
||||
if isinstance(chunks_before, list):
|
||||
await session.execute(delete(Chunk).where(Chunk.document_id == doc.id))
|
||||
chunk_texts = [
|
||||
str(c.get("content"))
|
||||
for c in chunks_before
|
||||
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
||||
]
|
||||
if chunk_texts:
|
||||
chunk_embeddings = embed_texts(chunk_texts)
|
||||
session.add_all(
|
||||
[
|
||||
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
||||
for text, embedding in zip(
|
||||
chunk_texts, chunk_embeddings, strict=True
|
||||
)
|
||||
]
|
||||
)
|
||||
# Re-embed the document itself from the restored content.
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether this branch is nested under the chunks_before check — confirm.
if isinstance(revision.content_before, str):
|
||||
doc.embedding = embed_texts([revision.content_before])[0]
|
||||
|
||||
doc.updated_at = datetime.now(UTC)
|
||||
return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
||||
|
||||
|
||||
async def _reinsert_document_from_revision(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Recreate a deleted note from its snapshot row (the ``rm`` revert path).

    Outcomes:

    * ``not_reversible`` — the snapshot has no non-empty string
      ``title_before``, no string ``content_before``, or no virtual path
      can be derived from it.
    * ``tool_unavailable`` — a document with the same unique-identifier
      hash already exists in the snapshot's search space.
    * ``ok`` — a new ``Document`` (and any snapshotted chunks) was added to
      the session and ``revision.document_id`` now points at it.
    """
    title = revision.title_before
    if not (isinstance(title, str) and title):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks title_before; cannot recreate document.",
        )

    content = revision.content_before
    if not isinstance(content, str):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks content_before; cannot recreate document.",
        )

    virtual_path = await _virtual_path_from_snapshot(session, revision)
    if not virtual_path:
        return RevertOutcome(
            status="not_reversible",
            message=(
                "Snapshot is missing both metadata_before['virtual_path'] AND "
                "a resolvable (folder_id_before, title_before) pair."
            ),
        )

    search_space_id = revision.search_space_id
    unique_identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )

    # Refuse to recreate on top of a live document occupying the same slot.
    existing = await session.execute(
        select(Document.id).where(
            Document.search_space_id == search_space_id,
            Document.unique_identifier_hash == unique_identifier_hash,
        )
    )
    if existing.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                f"A document already exists at '{virtual_path}'; revert would "
                "collide. Move the live doc out of the way first."
            ),
        )

    # Work on a copy of the snapshot metadata (or a fresh dict when the
    # snapshot holds none / a non-dict) and pin the resolved virtual path.
    snapshot_meta = revision.metadata_before
    restored_meta = dict(snapshot_meta) if isinstance(snapshot_meta, dict) else {}
    restored_meta["virtual_path"] = virtual_path

    new_doc = Document(
        title=title,
        document_type=DocumentType.NOTE,
        document_metadata=restored_meta,
        content=content,
        content_hash=generate_content_hash(content, search_space_id),
        unique_identifier_hash=unique_identifier_hash,
        source_markdown=content,
        search_space_id=search_space_id,
        folder_id=revision.folder_id_before,
        updated_at=datetime.now(UTC),
    )
    session.add(new_doc)
    # Flush so ``new_doc.id`` is populated before chunks reference it.
    await session.flush()

    new_doc.embedding = embed_texts([content])[0]

    # Only dict entries carrying a string ``content`` are restored as chunks.
    chunk_texts: list[str] = []
    snapshot_chunks = revision.chunks_before
    if isinstance(snapshot_chunks, list):
        for entry in snapshot_chunks:
            if isinstance(entry, dict) and isinstance(entry.get("content"), str):
                chunk_texts.append(str(entry.get("content")))

    if chunk_texts:
        chunk_embeddings = embed_texts(chunk_texts)
        restored_chunks = [
            Chunk(document_id=new_doc.id, content=text, embedding=embedding)
            for text, embedding in zip(chunk_texts, chunk_embeddings, strict=True)
        ]
        session.add_all(restored_chunks)

    # Repoint the snapshot at the recreated row so a follow-up revert of
    # the same row works as expected.
    revision.document_id = new_doc.id
    return RevertOutcome(
        status="ok",
        message=f"Re-inserted document '{revision.title_before}' from snapshot.",
    )
|
||||
|
||||
|
||||
async def _delete_created_document(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Undo a document creation by deleting the row the revision points at.

    When ``revision.document_id`` is ``None`` there is nothing to delete and
    the call still reports ``ok``; otherwise a ``DELETE`` for that document
    id is executed on ``session``.
    """
    target_id = revision.document_id
    if target_id is not None:
        await session.execute(delete(Document).where(Document.id == target_id))
        return RevertOutcome(
            status="ok",
            message="Deleted the document that was created by this action.",
        )
    return RevertOutcome(
        status="ok",
        message="No live row to delete (already removed elsewhere).",
    )
|
||||
|
||||
|
||||
async def _restore_document_revision(
|
||||
session: AsyncSession, *, action: AgentActionLog
|
||||
) -> RevertOutcome:
|
||||
"""Restore the most recent :class:`DocumentRevision` for ``action``."""
|
||||
"""Dispatch document-level revert based on ``action.tool_name``."""
|
||||
stmt = (
|
||||
select(DocumentRevision)
|
||||
.where(DocumentRevision.agent_action_id == action.id)
|
||||
|
|
@ -132,23 +383,111 @@ async def _restore_document_revision(
|
|||
message="No document_revisions row tied to this action.",
|
||||
)
|
||||
|
||||
from app.db import Document # late import to avoid cycles at module load
|
||||
tool_name = (action.tool_name or "").lower()
|
||||
|
||||
doc = await session.get(Document, revision.document_id)
|
||||
if doc is None:
|
||||
if tool_name == "rm":
|
||||
return await _reinsert_document_from_revision(session, revision=revision)
|
||||
|
||||
if tool_name == "write_file" and revision.content_before is None:
|
||||
return await _delete_created_document(session, revision=revision)
|
||||
|
||||
return await _restore_in_place_document(session, revision=revision)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Folder revision restore (mkdir/rmdir/rename/move)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _restore_in_place_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Write the snapshot's name / parent / position back onto the live folder.

    Returns ``tool_unavailable`` when the revision no longer references a
    folder id, or when no ``Folder`` row exists for that id. Otherwise the
    three snapshot fields are applied through ``_set_field``, ``updated_at``
    is bumped, and ``ok`` is returned.
    """
    if revision.folder_id is None:
        # Exactly one ``message`` keyword: the folder-specific text. The
        # document-specific message that was also present here was a stale
        # leftover and a repeated keyword argument is a syntax error.
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder was hard-deleted; in-place restore is impossible.",
        )
    folder = await session.get(Folder, revision.folder_id)
    if folder is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder has been deleted; revert cannot proceed.",
        )
    _set_field(folder, "name", revision.name_before)
    _set_field(folder, "parent_id", revision.parent_id_before)
    _set_field(folder, "position", revision.position_before)
    folder.updated_at = datetime.now(UTC)
    return RevertOutcome(status="ok", message="Folder restored from snapshot.")
|
||||
|
||||
|
||||
async def _reinsert_folder_from_revision(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Recreate a folder row from the values captured in ``revision``.

    Returns ``not_reversible`` when ``name_before`` is not a non-empty
    string. Otherwise a new ``Folder`` is added and flushed, the revision is
    repointed at the new row's id, and ``ok`` is returned.
    """
    folder_name = revision.name_before
    if not (isinstance(folder_name, str) and folder_name):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks name_before; cannot recreate folder.",
        )

    recreated = Folder(
        name=folder_name,
        parent_id=revision.parent_id_before,
        position=revision.position_before,
        search_space_id=revision.search_space_id,
        updated_at=datetime.now(UTC),
    )
    session.add(recreated)
    # Flush so the new row's id is available for the revision pointer.
    await session.flush()
    revision.folder_id = recreated.id

    return RevertOutcome(
        status="ok",
        message=f"Re-inserted folder '{revision.name_before}' from snapshot.",
    )
|
||||
|
||||
|
||||
async def _delete_created_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Delete the folder referenced by ``revision`` if it is still empty.

    Returns ``ok`` without touching the database when the revision has no
    ``folder_id``. Returns ``tool_unavailable`` when at least one document
    or one child folder references the folder. Otherwise executes a
    ``DELETE`` for the folder row and returns ``ok``.
    """
    if revision.folder_id is None:
        return RevertOutcome(
            status="ok",
            message="No live folder row to delete (already removed elsewhere).",
        )
    folder_id = revision.folder_id

    # ``limit(1)`` — only existence matters, not how many rows there are.
    has_doc = await session.execute(
        select(Document.id).where(Document.folder_id == folder_id).limit(1)
    )
    if has_doc.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Folder is no longer empty (documents have been added since "
                "mkdir); cannot revert."
            ),
        )
    has_child = await session.execute(
        select(Folder.id).where(Folder.parent_id == folder_id).limit(1)
    )
    if has_child.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Folder is no longer empty (sub-folders have been added "
                "since mkdir); cannot revert."
            ),
        )

    # The stale document-restore statements that sat here referenced an
    # undefined ``doc`` and returned before the delete could run; they are
    # dropped so the folder delete is actually reached.
    await session.execute(delete(Folder).where(Folder.id == folder_id))
    return RevertOutcome(
        status="ok",
        message="Deleted the folder that was created by this action.",
    )
|
||||
|
||||
|
||||
async def _restore_folder_revision(
|
||||
|
|
@ -168,41 +507,44 @@ async def _restore_folder_revision(
|
|||
message="No folder_revisions row tied to this action.",
|
||||
)
|
||||
|
||||
from app.db import Folder
|
||||
tool_name = (action.tool_name or "").lower()
|
||||
|
||||
folder = await session.get(Folder, revision.folder_id)
|
||||
if folder is None:
|
||||
return RevertOutcome(
|
||||
status="tool_unavailable",
|
||||
message="Original folder has been deleted; revert cannot proceed.",
|
||||
)
|
||||
if tool_name == "rmdir":
|
||||
return await _reinsert_folder_from_revision(session, revision=revision)
|
||||
|
||||
if revision.name_before is not None:
|
||||
folder.name = revision.name_before
|
||||
if revision.parent_id_before is not None:
|
||||
folder.parent_id = revision.parent_id_before
|
||||
if revision.position_before is not None:
|
||||
folder.position = revision.position_before
|
||||
folder.updated_at = datetime.now(UTC)
|
||||
return RevertOutcome(status="ok", message="Folder restored from snapshot.")
|
||||
if tool_name == "mkdir":
|
||||
return await _delete_created_folder(session, revision=revision)
|
||||
|
||||
return await _restore_in_place_folder(session, revision=revision)
|
||||
|
||||
|
||||
# Tool-name prefixes that route to KB document / folder revert paths. Kept
|
||||
# as data so a future PR adding new KB-owned tools doesn't have to touch
|
||||
# this module's control flow.
|
||||
_DOC_TOOL_PREFIXES: tuple[str, ...] = (
|
||||
"edit_file",
|
||||
"write_file",
|
||||
"update_memory",
|
||||
"create_note",
|
||||
"update_note",
|
||||
"delete_note",
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``.
|
||||
# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and
|
||||
# ``delete_note``/``delete_folder``.
|
||||
|
||||
_DOC_TOOLS: frozenset[str] = frozenset(
|
||||
{
|
||||
"edit_file",
|
||||
"write_file",
|
||||
"move_file",
|
||||
"rm",
|
||||
"update_memory",
|
||||
"create_note",
|
||||
"update_note",
|
||||
"delete_note",
|
||||
}
|
||||
)
|
||||
_FOLDER_TOOL_PREFIXES: tuple[str, ...] = (
|
||||
"mkdir",
|
||||
"move_file",
|
||||
"rename_folder",
|
||||
"delete_folder",
|
||||
_FOLDER_TOOLS: frozenset[str] = frozenset(
|
||||
{
|
||||
"mkdir",
|
||||
"rmdir",
|
||||
"rename_folder",
|
||||
"delete_folder",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -220,9 +562,9 @@ async def revert_action(
|
|||
"""
|
||||
tool_name = (action.tool_name or "").lower()
|
||||
|
||||
if tool_name.startswith(_DOC_TOOL_PREFIXES):
|
||||
if tool_name in _DOC_TOOLS:
|
||||
outcome = await _restore_document_revision(session, action=action)
|
||||
elif tool_name.startswith(_FOLDER_TOOL_PREFIXES):
|
||||
elif tool_name in _FOLDER_TOOLS:
|
||||
outcome = await _restore_folder_revision(session, action=action)
|
||||
elif action.reverse_descriptor:
|
||||
# Connector-owned reversibles run through the normal permission
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from sqlalchemy.orm import selectinload
|
|||
from app.agents.multi_agent_chat.integration import create_multi_agent_chat
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.agents.new_chat.feature_flags import get_flags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import (
|
||||
AgentConfig,
|
||||
|
|
@ -72,6 +73,91 @@ _perf_log = get_perf_logger()
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
||||
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
|
||||
|
||||
Returns a dict with three keys:
|
||||
|
||||
* ``text`` — concatenated string content (empty string if the chunk
|
||||
contributes none).
|
||||
* ``reasoning`` — concatenated reasoning content (empty string if the
|
||||
chunk contributes none).
|
||||
* ``tool_call_chunks`` — flat list of LangChain ``tool_call_chunk``
|
||||
dicts surfaced from either the typed-block list or the
|
||||
``tool_call_chunks`` attribute.
|
||||
|
||||
Background
|
||||
----------
|
||||
``AIMessageChunk.content`` can be:
|
||||
|
||||
* a ``str`` (most providers), or
|
||||
* a ``list`` of typed blocks ``{type: 'text' | 'reasoning' |
|
||||
'tool_call_chunk' | 'tool_use' | ..., text/content/...}`` for
|
||||
Anthropic, Bedrock, and several reasoning configurations.
|
||||
|
||||
Reasoning may also live under
|
||||
``chunk.additional_kwargs['reasoning_content']`` (some providers
|
||||
surface it that way instead of as a typed block). Tool-call chunks
|
||||
may live under ``chunk.tool_call_chunks`` even when ``content`` is a
|
||||
plain string.
|
||||
|
||||
Earlier versions only handled the ``isinstance(content, str)`` branch
|
||||
and silently dropped reasoning blocks + tool-call chunks emitted by
|
||||
LangChain ``AIMessageChunk``s.
|
||||
"""
|
||||
out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []}
|
||||
if chunk is None:
|
||||
return out
|
||||
|
||||
content = getattr(chunk, "content", None)
|
||||
if isinstance(content, str):
|
||||
if content:
|
||||
out["text"] = content
|
||||
elif isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
reasoning_parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
value = block.get("text") or block.get("content") or ""
|
||||
if isinstance(value, str) and value:
|
||||
text_parts.append(value)
|
||||
elif block_type == "reasoning":
|
||||
value = (
|
||||
block.get("reasoning")
|
||||
or block.get("text")
|
||||
or block.get("content")
|
||||
or ""
|
||||
)
|
||||
if isinstance(value, str) and value:
|
||||
reasoning_parts.append(value)
|
||||
elif block_type in ("tool_call_chunk", "tool_use"):
|
||||
out["tool_call_chunks"].append(block)
|
||||
if text_parts:
|
||||
out["text"] = "".join(text_parts)
|
||||
if reasoning_parts:
|
||||
out["reasoning"] = "".join(reasoning_parts)
|
||||
|
||||
additional = getattr(chunk, "additional_kwargs", None) or {}
|
||||
if isinstance(additional, dict):
|
||||
extra_reasoning = additional.get("reasoning_content")
|
||||
if isinstance(extra_reasoning, str) and extra_reasoning:
|
||||
existing = out["reasoning"]
|
||||
out["reasoning"] = (
|
||||
(existing + extra_reasoning) if existing else extra_reasoning
|
||||
)
|
||||
|
||||
extra_tool_chunks = getattr(chunk, "tool_call_chunks", None)
|
||||
if isinstance(extra_tool_chunks, list):
|
||||
for tcc in extra_tool_chunks:
|
||||
if isinstance(tcc, dict):
|
||||
out["tool_call_chunks"].append(tcc)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def format_mentioned_surfsense_docs_as_context(
|
||||
documents: list[SurfsenseDocsDocument],
|
||||
) -> str:
|
||||
|
|
@ -254,6 +340,42 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
|
|||
)
|
||||
|
||||
|
||||
def _legacy_match_lc_id(
|
||||
pending_tool_call_chunks: list[dict[str, Any]],
|
||||
tool_name: str,
|
||||
run_id: str,
|
||||
lc_tool_call_id_by_run: dict[str, str],
|
||||
) -> str | None:
|
||||
"""Best-effort match a buffered ``tool_call_chunk`` to a tool name.
|
||||
|
||||
Pure extract of the legacy in-line match used at ``on_tool_start`` for
|
||||
parity_v2-OFF and unmatched (chunk path didn't register an index for
|
||||
this call) tools. Pops the next id-bearing chunk whose ``name``
|
||||
matches ``tool_name`` (or any id-bearing chunk as a fallback) and
|
||||
returns its id. Mutates ``pending_tool_call_chunks`` and
|
||||
``lc_tool_call_id_by_run`` in place.
|
||||
"""
|
||||
matched_idx: int | None = None
|
||||
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||
if tcc.get("name") == tool_name and tcc.get("id"):
|
||||
matched_idx = idx
|
||||
break
|
||||
if matched_idx is None:
|
||||
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||
if tcc.get("id"):
|
||||
matched_idx = idx
|
||||
break
|
||||
if matched_idx is None:
|
||||
return None
|
||||
matched = pending_tool_call_chunks.pop(matched_idx)
|
||||
candidate = matched.get("id")
|
||||
if isinstance(candidate, str) and candidate:
|
||||
if run_id:
|
||||
lc_tool_call_id_by_run[run_id] = candidate
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
async def _stream_agent_events(
|
||||
agent: Any,
|
||||
config: dict[str, Any],
|
||||
|
|
@ -268,6 +390,7 @@ async def _stream_agent_events(
|
|||
fallback_commit_search_space_id: int | None = None,
|
||||
fallback_commit_created_by_id: str | None = None,
|
||||
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||
fallback_commit_thread_id: int | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Shared async generator that streams and formats astream_events from the agent.
|
||||
|
||||
|
|
@ -300,6 +423,59 @@ async def _stream_agent_events(
|
|||
active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool
|
||||
called_update_memory: bool = False
|
||||
|
||||
# Reasoning-block streaming. We open a reasoning block on the
|
||||
# first reasoning delta of a step, append deltas as they arrive, and
|
||||
# close it when text starts (the model has switched to writing its
|
||||
# answer) or ``on_chat_model_end`` fires for the model node. Reuses
|
||||
# the same Vercel format-helpers as text-start/delta/end.
|
||||
current_reasoning_id: str | None = None
|
||||
|
||||
# Streaming-parity v2 feature flag. When OFF we keep the legacy
|
||||
# shape: str-only content, no reasoning blocks, no
|
||||
# ``langchainToolCallId`` propagation. The schema migrations
|
||||
# (135 / 136) ship unconditionally because they're forward-compatible.
|
||||
parity_v2 = bool(get_flags().enable_stream_parity_v2)
|
||||
|
||||
# Best-effort attach of LangChain ``tool_call_id`` to the synthetic
|
||||
# ``call_<run_id>`` card id we already emit. We accumulate
|
||||
# ``tool_call_chunks`` from ``on_chat_model_stream``, key them by
|
||||
# name, and pop the next unconsumed entry at ``on_tool_start``. The
|
||||
# authoritative id is later filled in at ``on_tool_end`` from
|
||||
# ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit
|
||||
# this list for chunks that already registered into ``index_to_meta``
|
||||
# below — so this list is reserved for the parity_v2-OFF / unmatched
|
||||
# fallback path only and never re-pops a chunk we already streamed.
|
||||
pending_tool_call_chunks: list[dict[str, Any]] = []
|
||||
lc_tool_call_id_by_run: dict[str, str] = {}
|
||||
|
||||
# parity_v2 only: live tool-call argument streaming. ``index_to_meta``
|
||||
# is keyed by the chunk's ``index`` field — LangChain
|
||||
# ``ToolCallChunk``s for the same call share an index but only the
|
||||
# first chunk carries id+name (subsequent ones are id=None,
|
||||
# name=None, args="<delta>"). We register an index when both id and
|
||||
# name are observed on a chunk (per ToolCallChunk semantics they
|
||||
# arrive together on the first chunk), then route every later chunk
|
||||
# at that index to the same ``ui_id`` as a ``tool-input-delta``.
|
||||
# ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the
|
||||
# ``ui_id`` used for that call's ``tool-input-start`` so the matching
|
||||
# ``tool-output-available`` (emitted from ``on_tool_end``) lands on
|
||||
# the same card.
|
||||
index_to_meta: dict[int, dict[str, str]] = {}
|
||||
ui_tool_call_id_by_run: dict[str, str] = {}
|
||||
|
||||
# Per-tool-end mutable cache for the LangChain tool_call_id resolved
|
||||
# at ``on_tool_end``. ``_emit_tool_output`` reads this so every
|
||||
# ``format_tool_output_available`` call automatically carries the
|
||||
# authoritative id without duplicating the kwarg at every call site.
|
||||
current_lc_tool_call_id: dict[str, str | None] = {"value": None}
|
||||
|
||||
def _emit_tool_output(call_id: str, output: Any) -> str:
|
||||
return streaming_service.format_tool_output_available(
|
||||
call_id,
|
||||
output,
|
||||
langchain_tool_call_id=current_lc_tool_call_id["value"],
|
||||
)
|
||||
|
||||
def next_thinking_step_id() -> str:
|
||||
nonlocal thinking_step_counter
|
||||
thinking_step_counter += 1
|
||||
|
|
@ -328,22 +504,119 @@ async def _stream_agent_events(
|
|||
if "surfsense:internal" in event.get("tags", []):
|
||||
continue # Suppress middleware-internal LLM tokens (e.g. KB search classification)
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk and hasattr(chunk, "content"):
|
||||
content = chunk.content
|
||||
if content and isinstance(content, str):
|
||||
if current_text_id is None:
|
||||
completion_event = complete_current_step()
|
||||
if completion_event:
|
||||
yield completion_event
|
||||
if just_finished_tool:
|
||||
last_active_step_id = None
|
||||
last_active_step_title = ""
|
||||
last_active_step_items = []
|
||||
just_finished_tool = False
|
||||
current_text_id = streaming_service.generate_text_id()
|
||||
yield streaming_service.format_text_start(current_text_id)
|
||||
yield streaming_service.format_text_delta(current_text_id, content)
|
||||
accumulated_text += content
|
||||
if not chunk:
|
||||
continue
|
||||
parts = _extract_chunk_parts(chunk)
|
||||
|
||||
reasoning_delta = parts["reasoning"]
|
||||
text_delta = parts["text"]
|
||||
|
||||
# Reasoning streaming. Open a reasoning block on first
|
||||
# delta; append every subsequent delta until text begins.
|
||||
# When text starts we close the reasoning block first so the
|
||||
# frontend sees the natural hand-off. Gated behind the
|
||||
# parity-v2 flag so legacy deployments keep today's shape.
|
||||
if parity_v2 and reasoning_delta:
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
current_text_id = None
|
||||
if current_reasoning_id is None:
|
||||
completion_event = complete_current_step()
|
||||
if completion_event:
|
||||
yield completion_event
|
||||
if just_finished_tool:
|
||||
last_active_step_id = None
|
||||
last_active_step_title = ""
|
||||
last_active_step_items = []
|
||||
just_finished_tool = False
|
||||
current_reasoning_id = streaming_service.generate_reasoning_id()
|
||||
yield streaming_service.format_reasoning_start(current_reasoning_id)
|
||||
yield streaming_service.format_reasoning_delta(
|
||||
current_reasoning_id, reasoning_delta
|
||||
)
|
||||
|
||||
if text_delta:
|
||||
if current_reasoning_id is not None:
|
||||
yield streaming_service.format_reasoning_end(current_reasoning_id)
|
||||
current_reasoning_id = None
|
||||
if current_text_id is None:
|
||||
completion_event = complete_current_step()
|
||||
if completion_event:
|
||||
yield completion_event
|
||||
if just_finished_tool:
|
||||
last_active_step_id = None
|
||||
last_active_step_title = ""
|
||||
last_active_step_items = []
|
||||
just_finished_tool = False
|
||||
current_text_id = streaming_service.generate_text_id()
|
||||
yield streaming_service.format_text_start(current_text_id)
|
||||
yield streaming_service.format_text_delta(current_text_id, text_delta)
|
||||
accumulated_text += text_delta
|
||||
|
||||
# Live tool-call argument streaming. Runs AFTER text/reasoning
|
||||
# processing so chunks containing both stay in their natural
|
||||
# wire order (text → text-end → tool-input-start). Active
|
||||
# text/reasoning are closed inside the registration branch
|
||||
# before ``tool-input-start`` so the frontend sees a clean
|
||||
# part boundary even when providers interleave.
|
||||
if parity_v2 and parts["tool_call_chunks"]:
|
||||
for tcc in parts["tool_call_chunks"]:
|
||||
idx = tcc.get("index")
|
||||
|
||||
# Register this index when we first see id+name
|
||||
# TOGETHER. Per LangChain ToolCallChunk semantics the
|
||||
# first chunk for a tool call carries both fields
|
||||
# together; later chunks have id=None, name=None and
|
||||
# only ``args``. Requiring BOTH keeps wire
|
||||
# ``tool-input-start`` always carrying a real
|
||||
# toolName (assistant-ui's typed tool-part dispatch
|
||||
# keys off it).
|
||||
if idx is not None and idx not in index_to_meta:
|
||||
lc_id = tcc.get("id")
|
||||
name = tcc.get("name")
|
||||
if lc_id and name:
|
||||
ui_id = lc_id
|
||||
|
||||
# Close active text/reasoning so wire
|
||||
# ordering stays clean even on providers
|
||||
# that interleave text and tool-call chunks
|
||||
# within the same stream window.
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
current_text_id = None
|
||||
if current_reasoning_id is not None:
|
||||
yield streaming_service.format_reasoning_end(
|
||||
current_reasoning_id
|
||||
)
|
||||
current_reasoning_id = None
|
||||
|
||||
index_to_meta[idx] = {
|
||||
"ui_id": ui_id,
|
||||
"lc_id": lc_id,
|
||||
"name": name,
|
||||
}
|
||||
yield streaming_service.format_tool_input_start(
|
||||
ui_id,
|
||||
name,
|
||||
langchain_tool_call_id=lc_id,
|
||||
)
|
||||
|
||||
# Emit args delta for any chunk at a registered
|
||||
# index (including idless continuations). Once an
|
||||
# index is owned by ``index_to_meta`` we DO NOT
|
||||
# append to ``pending_tool_call_chunks`` — that list
|
||||
# is reserved for the parity_v2-OFF / unmatched
|
||||
# fallback path so it never re-pops chunks already
|
||||
# consumed here (skip-append).
|
||||
meta = index_to_meta.get(idx) if idx is not None else None
|
||||
if meta:
|
||||
args_chunk = tcc.get("args") or ""
|
||||
if args_chunk:
|
||||
yield streaming_service.format_tool_input_delta(
|
||||
meta["ui_id"], args_chunk
|
||||
)
|
||||
else:
|
||||
pending_tool_call_chunks.append(tcc)
|
||||
|
||||
elif event_type == "on_tool_start":
|
||||
active_tool_depth += 1
|
||||
|
|
@ -463,6 +736,95 @@ async def _stream_agent_events(
|
|||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "rm":
|
||||
rm_path = (
|
||||
tool_input.get("path", "")
|
||||
if isinstance(tool_input, dict)
|
||||
else str(tool_input)
|
||||
)
|
||||
display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:]
|
||||
last_active_step_title = "Deleting file"
|
||||
last_active_step_items = [display_path] if display_path else []
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Deleting file",
|
||||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "rmdir":
|
||||
rmdir_path = (
|
||||
tool_input.get("path", "")
|
||||
if isinstance(tool_input, dict)
|
||||
else str(tool_input)
|
||||
)
|
||||
display_path = (
|
||||
rmdir_path if len(rmdir_path) <= 80 else "…" + rmdir_path[-77:]
|
||||
)
|
||||
last_active_step_title = "Deleting folder"
|
||||
last_active_step_items = [display_path] if display_path else []
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Deleting folder",
|
||||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "mkdir":
|
||||
mkdir_path = (
|
||||
tool_input.get("path", "")
|
||||
if isinstance(tool_input, dict)
|
||||
else str(tool_input)
|
||||
)
|
||||
display_path = (
|
||||
mkdir_path if len(mkdir_path) <= 80 else "…" + mkdir_path[-77:]
|
||||
)
|
||||
last_active_step_title = "Creating folder"
|
||||
last_active_step_items = [display_path] if display_path else []
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Creating folder",
|
||||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "move_file":
|
||||
src = (
|
||||
tool_input.get("source_path", "")
|
||||
if isinstance(tool_input, dict)
|
||||
else ""
|
||||
)
|
||||
dst = (
|
||||
tool_input.get("destination_path", "")
|
||||
if isinstance(tool_input, dict)
|
||||
else ""
|
||||
)
|
||||
display_src = src if len(src) <= 60 else "…" + src[-57:]
|
||||
display_dst = dst if len(dst) <= 60 else "…" + dst[-57:]
|
||||
last_active_step_title = "Moving file"
|
||||
last_active_step_items = (
|
||||
[f"{display_src} → {display_dst}"] if src or dst else []
|
||||
)
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Moving file",
|
||||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "write_todos":
|
||||
todos = (
|
||||
tool_input.get("todos", []) if isinstance(tool_input, dict) else []
|
||||
)
|
||||
todo_count = len(todos) if isinstance(todos, list) else 0
|
||||
last_active_step_title = "Planning tasks"
|
||||
last_active_step_items = (
|
||||
[f"{todo_count} task{'s' if todo_count != 1 else ''}"]
|
||||
if todo_count
|
||||
else []
|
||||
)
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Planning tasks",
|
||||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "save_document":
|
||||
doc_title = (
|
||||
tool_input.get("title", "")
|
||||
|
|
@ -570,7 +932,15 @@ async def _stream_agent_events(
|
|||
items=last_active_step_items,
|
||||
)
|
||||
else:
|
||||
last_active_step_title = f"Using {tool_name.replace('_', ' ')}"
|
||||
# Fallback for tools without a curated thinking-step title
|
||||
# (typically connector tools, MCP-registered tools, or
|
||||
# newly added tools that haven't been wired up here yet).
|
||||
# Render the snake_cased name as a sentence-cased phrase
|
||||
# so non-technical users see e.g. "Send gmail email"
|
||||
# rather than the raw identifier "send_gmail_email".
|
||||
last_active_step_title = (
|
||||
tool_name.replace("_", " ").strip().capitalize() or tool_name
|
||||
)
|
||||
last_active_step_items = []
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
|
|
@ -578,12 +948,65 @@ async def _stream_agent_events(
|
|||
status="in_progress",
|
||||
)
|
||||
|
||||
tool_call_id = (
|
||||
f"call_{run_id[:32]}"
|
||||
if run_id
|
||||
else streaming_service.generate_tool_call_id()
|
||||
)
|
||||
yield streaming_service.format_tool_input_start(tool_call_id, tool_name)
|
||||
# Resolve the card identity. If the chunk-emission loop
|
||||
# already registered an ``index`` for this tool call (parity_v2
|
||||
# path), reuse the same ui_id so the card sees:
|
||||
# tool-input-start → deltas… → tool-input-available →
|
||||
# tool-output-available all keyed by lc_id. Otherwise fall
|
||||
# back to the synthetic ``call_<run_id>`` id and the legacy
|
||||
# best-effort match against ``pending_tool_call_chunks``.
|
||||
matched_meta: dict[str, str] | None = None
|
||||
if parity_v2:
|
||||
# FIFO over indices 0,1,2…; first unassigned same-name
|
||||
# match wins. Handles parallel same-name calls (e.g. two
|
||||
# write_file calls) deterministically as long as the
|
||||
# model interleaves on_tool_start in the same order it
|
||||
# streamed the args.
|
||||
taken_ui_ids = set(ui_tool_call_id_by_run.values())
|
||||
for meta in index_to_meta.values():
|
||||
if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
|
||||
matched_meta = meta
|
||||
break
|
||||
|
||||
tool_call_id: str
|
||||
langchain_tool_call_id: str | None = None
|
||||
if matched_meta is not None:
|
||||
tool_call_id = matched_meta["ui_id"]
|
||||
langchain_tool_call_id = matched_meta["lc_id"]
|
||||
# ``tool-input-start`` already fired during chunk
|
||||
# emission — skip the duplicate. No pruning is needed
|
||||
# because the chunk-emission loop intentionally never
|
||||
# appends registered-index chunks to
|
||||
# ``pending_tool_call_chunks`` (skip-append).
|
||||
if run_id:
|
||||
lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"]
|
||||
else:
|
||||
tool_call_id = (
|
||||
f"call_{run_id[:32]}"
|
||||
if run_id
|
||||
else streaming_service.generate_tool_call_id()
|
||||
)
|
||||
# Legacy fallback: parity_v2 OFF, or parity_v2 ON but the
|
||||
# provider didn't stream tool_call_chunks for this call
|
||||
# (no index registered). Run the existing best-effort
|
||||
# match BEFORE emitting start so we still attach an
|
||||
# authoritative ``langchainToolCallId`` when possible.
|
||||
if parity_v2:
|
||||
langchain_tool_call_id = _legacy_match_lc_id(
|
||||
pending_tool_call_chunks,
|
||||
tool_name,
|
||||
run_id,
|
||||
lc_tool_call_id_by_run,
|
||||
)
|
||||
yield streaming_service.format_tool_input_start(
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
langchain_tool_call_id=langchain_tool_call_id,
|
||||
)
|
||||
|
||||
if run_id:
|
||||
ui_tool_call_id_by_run[run_id] = tool_call_id
|
||||
|
||||
# Sanitize tool_input: strip runtime-injected non-serializable
|
||||
# values (e.g. LangChain ToolRuntime) before sending over SSE.
|
||||
if isinstance(tool_input, dict):
|
||||
|
|
@ -600,6 +1023,7 @@ async def _stream_agent_events(
|
|||
tool_call_id,
|
||||
tool_name,
|
||||
_safe_input,
|
||||
langchain_tool_call_id=langchain_tool_call_id,
|
||||
)
|
||||
|
||||
elif event_type == "on_tool_end":
|
||||
|
|
@ -635,12 +1059,42 @@ async def _stream_agent_events(
|
|||
result.write_succeeded = True
|
||||
result.verification_succeeded = True
|
||||
|
||||
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
|
||||
# Look up the SAME card id used at on_tool_start (either the
|
||||
# parity_v2 lc-id-derived ui_id or the legacy synthetic
|
||||
# ``call_<run_id>``) so the output event always lands on the
|
||||
# same card as start/delta/available. Fallback preserves the
|
||||
# legacy synthetic shape for parity_v2-OFF / unknown-run paths.
|
||||
tool_call_id = ui_tool_call_id_by_run.get(
|
||||
run_id,
|
||||
f"call_{run_id[:32]}" if run_id else "call_unknown",
|
||||
)
|
||||
original_step_id = tool_step_ids.get(
|
||||
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
|
||||
)
|
||||
completed_step_ids.add(original_step_id)
|
||||
|
||||
# Authoritative LangChain tool_call_id from the returned
|
||||
# ``ToolMessage``. Falls back to whatever we matched
|
||||
# at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``)
|
||||
# if the output isn't a ToolMessage. The value is stored in
|
||||
# ``current_lc_tool_call_id`` so ``_emit_tool_output``
|
||||
# picks it up for every output emit below.
|
||||
#
|
||||
# Emitted in BOTH parity_v2 and legacy modes: the chat tool
|
||||
# card needs the LangChain id to match against the
|
||||
# ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``)
|
||||
# so the inline Revert button can light up. Reading
|
||||
# ``raw_output.tool_call_id`` is a cheap, non-mutating attribute
|
||||
# access that is safe regardless of feature-flag state.
|
||||
current_lc_tool_call_id["value"] = None
|
||||
authoritative = getattr(raw_output, "tool_call_id", None)
|
||||
if isinstance(authoritative, str) and authoritative:
|
||||
current_lc_tool_call_id["value"] = authoritative
|
||||
if run_id:
|
||||
lc_tool_call_id_by_run[run_id] = authoritative
|
||||
elif run_id and run_id in lc_tool_call_id_by_run:
|
||||
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
|
||||
|
||||
if tool_name == "read_file":
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=original_step_id,
|
||||
|
|
@ -676,6 +1130,41 @@ async def _stream_agent_events(
|
|||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "rm":
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Deleting file",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "rmdir":
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Deleting folder",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "mkdir":
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Creating folder",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "move_file":
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Moving file",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "write_todos":
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Planning tasks",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "save_document":
|
||||
result_str = (
|
||||
tool_output.get("result", "")
|
||||
|
|
@ -927,9 +1416,14 @@ async def _stream_agent_events(
|
|||
items=completed_items,
|
||||
)
|
||||
else:
|
||||
# Fallback completion title — see the matching in-progress
|
||||
# branch above for the wording rationale.
|
||||
fallback_title = (
|
||||
tool_name.replace("_", " ").strip().capitalize() or tool_name
|
||||
)
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title=f"Using {tool_name.replace('_', ' ')}",
|
||||
title=fallback_title,
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
|
|
@ -940,7 +1434,7 @@ async def _stream_agent_events(
|
|||
last_active_step_items = []
|
||||
|
||||
if tool_name == "generate_podcast":
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
tool_output
|
||||
if isinstance(tool_output, dict)
|
||||
|
|
@ -965,7 +1459,7 @@ async def _stream_agent_events(
|
|||
"error",
|
||||
)
|
||||
elif tool_name == "generate_video_presentation":
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
tool_output
|
||||
if isinstance(tool_output, dict)
|
||||
|
|
@ -993,7 +1487,7 @@ async def _stream_agent_events(
|
|||
"error",
|
||||
)
|
||||
elif tool_name == "generate_image":
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
tool_output
|
||||
if isinstance(tool_output, dict)
|
||||
|
|
@ -1020,12 +1514,12 @@ async def _stream_agent_events(
|
|||
display_output["content_preview"] = (
|
||||
content[:500] + "..." if len(content) > 500 else content
|
||||
)
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
display_output,
|
||||
)
|
||||
else:
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
{"result": tool_output},
|
||||
)
|
||||
|
|
@ -1053,7 +1547,7 @@ async def _stream_agent_events(
|
|||
)
|
||||
result_text = _tool_output_to_text(tool_output)
|
||||
if _tool_output_has_error(tool_output):
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
{
|
||||
"status": "error",
|
||||
|
|
@ -1062,7 +1556,7 @@ async def _stream_agent_events(
|
|||
},
|
||||
)
|
||||
else:
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
{
|
||||
"status": "completed",
|
||||
|
|
@ -1072,7 +1566,7 @@ async def _stream_agent_events(
|
|||
)
|
||||
elif tool_name == "generate_report":
|
||||
# Stream the full report result so frontend can render the ReportCard
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
tool_output
|
||||
if isinstance(tool_output, dict)
|
||||
|
|
@ -1099,7 +1593,7 @@ async def _stream_agent_events(
|
|||
"error",
|
||||
)
|
||||
elif tool_name == "generate_resume":
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
tool_output
|
||||
if isinstance(tool_output, dict)
|
||||
|
|
@ -1150,7 +1644,7 @@ async def _stream_agent_events(
|
|||
"update_confluence_page",
|
||||
"delete_confluence_page",
|
||||
):
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
tool_output
|
||||
if isinstance(tool_output, dict)
|
||||
|
|
@ -1178,7 +1672,7 @@ async def _stream_agent_events(
|
|||
if fpath and fpath not in result.sandbox_files:
|
||||
result.sandbox_files.append(fpath)
|
||||
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
{
|
||||
"exit_code": exit_code,
|
||||
|
|
@ -1213,12 +1707,12 @@ async def _stream_agent_events(
|
|||
citations[chunk_url]["snippet"] = (
|
||||
content[:200] + "…" if len(content) > 200 else content
|
||||
)
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
{"status": "completed", "citations": citations},
|
||||
)
|
||||
else:
|
||||
yield streaming_service.format_tool_output_available(
|
||||
yield _emit_tool_output(
|
||||
tool_call_id,
|
||||
{"status": "completed", "result_length": len(str(tool_output))},
|
||||
)
|
||||
|
|
@ -1276,6 +1770,25 @@ async def _stream_agent_events(
|
|||
},
|
||||
)
|
||||
|
||||
elif event_type == "on_custom_event" and event.get("name") == "action_log":
|
||||
# Surface a freshly committed AgentActionLog row so the chat
|
||||
# tool card can render its Revert button immediately.
|
||||
data = event.get("data", {})
|
||||
if data.get("id") is not None:
|
||||
yield streaming_service.format_data("action-log", data)
|
||||
|
||||
elif (
|
||||
event_type == "on_custom_event"
|
||||
and event.get("name") == "action_log_updated"
|
||||
):
|
||||
# Reversibility flipped in kb_persistence after the SAVEPOINT
|
||||
# for a destructive op (rm/rmdir/move/edit/write) committed.
|
||||
# Frontend uses this to flip the card's Revert
|
||||
# button on without re-fetching the actions list.
|
||||
data = event.get("data", {})
|
||||
if data.get("id") is not None:
|
||||
yield streaming_service.format_data("action-log-updated", data)
|
||||
|
||||
elif event_type in ("on_chain_end", "on_agent_end"):
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
|
|
@ -1293,11 +1806,12 @@ async def _stream_agent_events(
|
|||
|
||||
# Safety net: if astream_events was cancelled before
|
||||
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
||||
# (dirty_paths / staged_dirs / pending_moves) will still be in the
|
||||
# checkpointed state. Run the SAME shared commit helper here so the
|
||||
# turn's writes don't get lost on client disconnect, then push the
|
||||
# delta back into the graph using `as_node=...` so reducers fire as if
|
||||
# the after_agent hook produced it.
|
||||
# (dirty_paths / staged_dirs / pending_moves / pending_deletes /
|
||||
# pending_dir_deletes) will still be in the checkpointed state. Run
|
||||
# the SAME shared commit helper here so the turn's writes don't get
|
||||
# lost on client disconnect, then push the delta back into the graph
|
||||
# using `as_node=...` so reducers fire as if the after_agent hook
|
||||
# produced it.
|
||||
if (
|
||||
fallback_commit_filesystem_mode == FilesystemMode.CLOUD
|
||||
and fallback_commit_search_space_id is not None
|
||||
|
|
@ -1305,6 +1819,8 @@ async def _stream_agent_events(
|
|||
(state_values.get("dirty_paths") or [])
|
||||
or (state_values.get("staged_dirs") or [])
|
||||
or (state_values.get("pending_moves") or [])
|
||||
or (state_values.get("pending_deletes") or [])
|
||||
or (state_values.get("pending_dir_deletes") or [])
|
||||
)
|
||||
):
|
||||
try:
|
||||
|
|
@ -1313,6 +1829,7 @@ async def _stream_agent_events(
|
|||
search_space_id=fallback_commit_search_space_id,
|
||||
created_by_id=fallback_commit_created_by_id,
|
||||
filesystem_mode=fallback_commit_filesystem_mode,
|
||||
thread_id=fallback_commit_thread_id,
|
||||
dispatch_events=False,
|
||||
)
|
||||
if delta:
|
||||
|
|
@ -1753,13 +2270,33 @@ async def stream_new_chat(
|
|||
|
||||
config = {
|
||||
"configurable": configurable,
|
||||
"recursion_limit": 80, # Increase from default 25 to allow more tool iterations
|
||||
# Effectively uncapped, matching the agent-level
|
||||
# ``with_config`` default in ``chat_deepagent.create_agent``
|
||||
# and the unbounded ``while(true)`` loop used by OpenCode's
|
||||
# ``session/processor.ts``. Real circuit-breakers live in
|
||||
# middleware: ``DoomLoopMiddleware`` (sliding-window tool
|
||||
# signature check), plus ``enable_tool_call_limit`` /
|
||||
# ``enable_model_call_limit`` when those flags are set. The
|
||||
# original LangGraph default of 25 (and our previous 80
|
||||
# bump) hit users on legitimate multi-tool plans.
|
||||
"recursion_limit": 10_000,
|
||||
}
|
||||
|
||||
# Start the message stream
|
||||
yield streaming_service.format_message_start()
|
||||
yield streaming_service.format_start_step()
|
||||
|
||||
# Surface the per-turn correlation id at the very start of the
|
||||
# stream so the frontend can stamp it onto the in-flight
|
||||
# assistant message and replay it via ``appendMessage``
|
||||
# for durable storage. Tool/action-log events DO carry it later,
|
||||
# but pure-text turns never produce action-log events; this
|
||||
# event guarantees the frontend learns the turn id regardless.
|
||||
yield streaming_service.format_data(
|
||||
"turn-info",
|
||||
{"chat_turn_id": stream_result.turn_id},
|
||||
)
|
||||
|
||||
# Initial thinking step - analyzing the request
|
||||
if mentioned_surfsense_docs:
|
||||
initial_title = "Analyzing referenced content"
|
||||
|
|
@ -1910,6 +2447,7 @@ async def stream_new_chat(
|
|||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
|
|
@ -2353,11 +2891,22 @@ async def stream_resume_chat(
|
|||
"request_id": request_id or "unknown",
|
||||
"turn_id": stream_result.turn_id,
|
||||
},
|
||||
"recursion_limit": 80,
|
||||
# See ``stream_new_chat`` above for rationale: effectively
|
||||
# uncapped to mirror the agent default and OpenCode's
|
||||
# session loop. Doom-loop / call-limit middleware enforce
|
||||
# the real ceiling.
|
||||
"recursion_limit": 10_000,
|
||||
}
|
||||
|
||||
yield streaming_service.format_message_start()
|
||||
yield streaming_service.format_start_step()
|
||||
# Same rationale as ``stream_new_chat``: emit the turn id so
|
||||
# resumed streams can be persisted with their correlation id
|
||||
# intact.
|
||||
yield streaming_service.format_data(
|
||||
"turn-info",
|
||||
{"chat_turn_id": stream_result.turn_id},
|
||||
)
|
||||
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
|
|
@ -2375,6 +2924,7 @@ async def stream_resume_chat(
|
|||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
|
|
|
|||
|
|
@ -15,6 +15,17 @@ from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
|
|||
from app.agents.new_chat.tools.registry import ToolDefinition
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeRuntime:
|
||||
"""Minimal stand-in for ``ToolRuntime`` used in unit tests.
|
||||
|
||||
``ActionLogMiddleware`` reads ``runtime.config['configurable']['turn_id']``
|
||||
to populate the new ``chat_turn_id`` column (see migration 135).
|
||||
"""
|
||||
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeRequest:
|
||||
"""Minimal stand-in for ToolCallRequest used in unit tests."""
|
||||
|
|
@ -120,6 +131,9 @@ class TestActionLogMiddlewarePersistence:
|
|||
"args": {"color": "red", "size": 3},
|
||||
"id": "tc-abc",
|
||||
},
|
||||
runtime=_FakeRuntime(
|
||||
config={"configurable": {"turn_id": "42:1700000000000"}}
|
||||
),
|
||||
)
|
||||
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
|
||||
handler = AsyncMock(return_value=result_msg)
|
||||
|
|
@ -142,6 +156,32 @@ class TestActionLogMiddlewarePersistence:
|
|||
assert row.error is None
|
||||
assert row.reverse_descriptor is None
|
||||
assert row.reversible is False
|
||||
# Migration 135: ``turn_id`` is the deprecated alias of ``tool_call_id``;
|
||||
# ``chat_turn_id`` comes from ``runtime.config['configurable']['turn_id']``.
|
||||
assert row.tool_call_id == "tc-abc"
|
||||
assert row.turn_id == "tc-abc"
|
||||
assert row.chat_turn_id == "42:1700000000000"
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_turn_id_none_when_runtime_missing(
    self, patch_get_flags, fake_session_factory
) -> None:
    """With ``runtime=None`` the persisted row keeps ``chat_turn_id`` as NULL."""
    captured, factory = fake_session_factory
    middleware = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)

    # The request carries a tool call id but no runtime to read a turn id from.
    tool_call = {"name": "make_widget", "args": {}, "id": "tc-1"}
    request = _FakeRequest(tool_call=tool_call, runtime=None)
    tool_reply = ToolMessage(content="ok", tool_call_id="tc-1")
    handler = AsyncMock(return_value=tool_reply)

    with (
        patch_get_flags(_enabled_flags()),
        patch("app.db.shielded_async_session", side_effect=lambda: factory()),
    ):
        await middleware.awrap_tool_call(request, handler)

    first_row = captured["rows"][0]
    assert first_row.tool_call_id == "tc-1"
    assert first_row.chat_turn_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_writes_row_on_failure_and_reraises(
|
||||
|
|
@ -293,6 +333,76 @@ class TestReverseDescriptor:
|
|||
assert row.reversible is False
|
||||
|
||||
|
||||
class TestActionLogDispatch:
    """Checks when ``adispatch_custom_event`` is (and is not) awaited."""

    @pytest.mark.asyncio
    async def test_dispatches_action_log_event_on_success(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        """A successful tool call results in exactly one ``action_log`` dispatch."""
        _captured, factory = fake_session_factory
        middleware = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")

        runtime = _FakeRuntime(config={"configurable": {"turn_id": "42:1700000000000"}})
        tool_call = {
            "name": "make_widget",
            "args": {"color": "red"},
            "id": "tc-evt",
        }
        request = _FakeRequest(tool_call=tool_call, runtime=runtime)
        handler = AsyncMock(
            return_value=ToolMessage(content="ok", tool_call_id="tc-evt", id="msg-42")
        )

        dispatch_mock = AsyncMock()
        with (
            patch_get_flags(_enabled_flags()),
            patch("app.db.shielded_async_session", side_effect=lambda: factory()),
            patch(
                "app.agents.new_chat.middleware.action_log.adispatch_custom_event",
                dispatch_mock,
            ),
        ):
            await middleware.awrap_tool_call(request, handler)

        dispatch_mock.assert_awaited_once()
        awaited = dispatch_mock.await_args
        assert awaited is not None

        # Positional arguments are (event name, payload dict).
        event_name = awaited.args[0]
        assert event_name == "action_log"
        payload = awaited.args[1]

        # Identifier fields are compared by value ...
        assert payload["lc_tool_call_id"] == "tc-evt"
        assert payload["chat_turn_id"] == "42:1700000000000"
        assert payload["tool_name"] == "make_widget"
        # ... while the flag fields must be the literal ``False``.
        for flag_key in ("reversible", "reverse_descriptor_present", "error"):
            assert payload[flag_key] is False

    @pytest.mark.asyncio
    async def test_no_dispatch_when_persistence_fails(self, patch_get_flags) -> None:
        """If commit fails the dispatch is suppressed (no row to surface)."""
        middleware = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
        )
        handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
        dispatch_mock = AsyncMock()

        # Opening a session raises, standing in for an unavailable database.
        with (
            patch_get_flags(_enabled_flags()),
            patch(
                "app.db.shielded_async_session",
                side_effect=RuntimeError("DB is down"),
            ),
            patch(
                "app.agents.new_chat.middleware.action_log.adispatch_custom_event",
                dispatch_mock,
            ),
        ):
            await middleware.awrap_tool_call(request, handler)

        dispatch_mock.assert_not_awaited()
|
||||
|
||||
|
||||
class TestArgsTruncation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_huge_args_payload_is_truncated(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,122 @@
|
|||
"""Tests for the desktop-mode safety ruleset.
|
||||
|
||||
In desktop mode the agent operates against the user's real disk with no
|
||||
revision history, so destructive filesystem operations must require
|
||||
explicit approval. These tests pin the set of tools that get the ``ask``
|
||||
gate so it cannot silently regress.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
||||
from app.agents.new_chat.permissions import (
|
||||
Rule,
|
||||
Ruleset,
|
||||
aggregate_action,
|
||||
evaluate_many,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# Local copy of the ruleset that ``chat_deepagent._build_compiled_agent_blocking``
# assembles when ``filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER``.
# Duplicating it here gives the rule contract its own focused regression test
# even though the full graph-build helper is hard to instantiate in unit tests.
DESKTOP_SAFETY_RULESET = Ruleset(
    rules=[
        Rule(permission=gated_tool, pattern="*", action="ask")
        for gated_tool in ("rm", "rmdir", "move_file", "edit_file", "write_file")
    ],
    origin="desktop_safety",
)

# Catch-all "allow everything" layer that the desktop rules are stacked on.
SURFSENSE_DEFAULTS = Ruleset(
    origin="surfsense_defaults",
    rules=[Rule(permission="*", pattern="*", action="allow")],
)
|
||||
|
||||
|
||||
def _action_for(tool_name: str, *rulesets: Ruleset) -> str:
    """Return the aggregate permission action for ``tool_name``.

    ``tool_name`` is passed to ``evaluate_many`` both as the permission key
    and as the single pattern to match; the rules it returns are then
    collapsed into one action by ``aggregate_action``.
    """
    return aggregate_action(evaluate_many(tool_name, [tool_name], *rulesets))
|
||||
|
||||
|
||||
class TestDesktopSafetyRulesGateDestructiveOps:
    """Pins which tools resolve to ``ask`` vs ``allow`` under the desktop rules."""

    @pytest.mark.parametrize(
        "tool_name",
        ("rm", "rmdir", "move_file", "edit_file", "write_file"),
    )
    def test_destructive_op_resolves_to_ask(self, tool_name: str) -> None:
        # The defaults layer allows everything; the desktop layer comes after
        # it and therefore has to win (last match wins).
        resolved = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        failure = (
            f"{tool_name} must require approval in desktop mode "
            f"(no revert path on real disk); got {resolved!r}"
        )
        assert resolved == "ask", failure

    @pytest.mark.parametrize(
        "tool_name",
        ("read_file", "ls", "list_tree", "grep", "glob", "cd", "pwd", "mkdir"),
    )
    def test_safe_ops_remain_allowed(self, tool_name: str) -> None:
        # Gating read-only or trivially reversible tools would make every
        # navigation step in desktop mode raise an interrupt.
        resolved = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        failure = f"{tool_name} should not be gated in desktop mode; got {resolved!r}"
        assert resolved == "allow", failure
|
||||
|
||||
|
||||
class TestDesktopSafetyOverridesAllowDefault:
    """Guards the layering order of the two rulesets."""

    def test_layer_order_last_match_wins(self) -> None:
        # Each entry is (rulesets in layering order, expected action for "rm").
        # Putting desktop_safety first lets the broad allow default win, which
        # would leave the safety net inert; this protects against accidentally
        # swapping the rulesets in ``_build_compiled_agent_blocking``.
        orderings = (
            # Layered the "wrong way": the broad allow comes last and wins.
            ((DESKTOP_SAFETY_RULESET, SURFSENSE_DEFAULTS), "allow"),
            # Correct order: defaults first, desktop_safety last, so ask wins.
            ((SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET), "ask"),
        )
        for layers, expected_action in orderings:
            assert _action_for("rm", *layers) == expected_action
|
||||
|
||||
|
||||
class TestPermissionMiddlewareIntegration:
    """Drives ``PermissionMiddleware.after_model`` with the layered rulesets."""

    def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None:
        """An ``rm`` tool call goes through the ask hook and a reject aborts it."""
        from langchain_core.messages import AIMessage

        from app.agents.new_chat.errors import RejectedError

        mw = PermissionMiddleware(rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET])

        # Stub the interrupt to a "reject" decision so the ask path can be
        # exercised without spinning up the LangGraph runtime. The stub also
        # records its invocations: ``RejectedError`` alone does not prove the
        # interrupt hook was actually reached.
        interrupt_calls: list[dict] = []

        def _reject_interrupt(**kwargs):
            interrupt_calls.append(kwargs)
            return {"decision_type": "reject"}

        mw._raise_interrupt = _reject_interrupt  # type: ignore[assignment]

        state = {
            "messages": [
                AIMessage(
                    content="",
                    tool_calls=[
                        {
                            "name": "rm",
                            "args": {"path": "/Users/me/Documents/important.docx"},
                            "id": "tc-rm",
                        }
                    ],
                )
            ]
        }

        class _FakeRuntime:
            config: dict = {"configurable": {"thread_id": "test"}}

        with pytest.raises(RejectedError):
            mw.after_model(state, _FakeRuntime())

        assert interrupt_calls, "rm must reach the interrupt (ask) hook in desktop mode"
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
"""Tests for the default auto-approval list in ``hitl.request_approval``.
|
||||
|
||||
These pin the policy that low-stakes connector creation tools (drafts,
|
||||
new-file creates) skip the HITL interrupt by default. Without this set,
|
||||
every "draft my newsletter" turn used to fire ~3 interrupts before any
|
||||
useful work happened.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.tools.hitl import (
|
||||
DEFAULT_AUTO_APPROVED_TOOLS,
|
||||
HITLResult,
|
||||
request_approval,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestDefaultAutoApprovedToolsList:
    """Pins the membership and type of ``hitl.DEFAULT_AUTO_APPROVED_TOOLS``."""

    def test_set_contains_expected_creation_tools(self) -> None:
        # If anyone changes the policy list, there should be a single test to
        # update so the contract stays explicit. Keep this in sync with
        # ``hitl.DEFAULT_AUTO_APPROVED_TOOLS``.
        expected = {
            "create_gmail_draft",
            "update_gmail_draft",
            "create_notion_page",
            "create_confluence_page",
            "create_google_drive_file",
            "create_dropbox_file",
            "create_onedrive_file",
        }
        assert expected == DEFAULT_AUTO_APPROVED_TOOLS

    def test_set_is_immutable(self) -> None:
        # A frozenset cannot be mutated at runtime, so the auto-approval
        # surface cannot be widened silently.
        assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset)

    def test_send_tools_are_not_auto_approved(self) -> None:
        # Sends, deletes and calendar-event writes must always prompt.
        must_prompt = (
            "send_gmail_email",
            "send_discord_message",
            "send_teams_message",
            "delete_notion_page",
            "create_calendar_event",
            "delete_calendar_event",
        )
        # Collect every offender so one failure lists all leaked tools rather
        # than stopping at the first.
        leaked = [name for name in must_prompt if name in DEFAULT_AUTO_APPROVED_TOOLS]
        assert not leaked, f"{', '.join(leaked)} must remain HITL-gated"
|
||||
|
||||
|
||||
class TestRequestApprovalAutoBypass:
    """Exercises the bypass branches of ``request_approval``."""

    def test_auto_approved_tool_skips_interrupt(self) -> None:
        # No interrupt is mocked and there is no runnable context here, so
        # getting a clean HITLResult back shows the interrupt was never
        # attempted for this tool.
        draft = {"to": "alice@example.com", "subject": "hi", "body": "hey"}
        outcome = request_approval(
            action_type="gmail_draft_creation",
            tool_name="create_gmail_draft",
            params=dict(draft),
        )
        assert isinstance(outcome, HITLResult)
        assert outcome.rejected is False
        assert outcome.decision_type == "auto_approved"
        # The params come back exactly as supplied (no user edits possible).
        assert outcome.params == draft

    def test_non_listed_tool_still_attempts_interrupt(self) -> None:
        # A tool outside the default list has to reach ``langgraph.interrupt``.
        # Outside a runnable context that call raises a RuntimeError, which is
        # the signal that the bypass did NOT fire.
        with pytest.raises(RuntimeError, match="runnable context"):
            request_approval(
                action_type="gmail_email_send",
                tool_name="send_gmail_email",
                params={"to": "alice@example.com", "subject": "hi", "body": "hey"},
            )

    def test_user_trusted_tools_still_take_precedence(self) -> None:
        # ``trusted_tools`` (the per-connector "always allow" from MCP/UI) must
        # keep working for tools that are not in the default list.
        outcome = request_approval(
            action_type="mcp_tool_call",
            tool_name="my_custom_mcp_tool",
            params={"x": 1},
            trusted_tools=["my_custom_mcp_tool"],
        )
        assert outcome.decision_type == "trusted"
        assert outcome.rejected is False

    def test_auto_approved_overrides_no_trusted_tools(self) -> None:
        # An empty ``trusted_tools`` must not stop a default-listed tool from
        # being auto-approved.
        outcome = request_approval(
            action_type="notion_page_creation",
            tool_name="create_notion_page",
            params={"title": "Plan"},
            trusted_tools=[],
        )
        assert outcome.decision_type == "auto_approved"
|
||||
|
|
@ -0,0 +1,333 @@
|
|||
"""Cloud-mode behavior tests for the new ``rm`` and ``rmdir`` filesystem tools.
|
||||
|
||||
The tools build ``Command(update=...)`` payloads that the persistence
|
||||
middleware applies at end of turn. These tests stub out the backend and
|
||||
runtime to assert the staging payload shape:
|
||||
|
||||
* ``rm`` queues into ``pending_deletes`` and tombstones state files.
|
||||
* ``rm`` rejects directories, ``/documents``, root, and the anonymous doc.
|
||||
* ``rmdir`` queues into ``pending_dir_deletes`` and rejects non-empty dirs.
|
||||
* ``rmdir`` un-stages a same-turn ``mkdir`` rather than queuing a delete.
|
||||
* ``rmdir`` refuses to drop the cwd or any of its ancestors.
|
||||
* ``KBPostgresBackend`` view-helpers honor staged deletes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD):
    """Build a bare ``SurfSenseFilesystemMiddleware`` without running ``__init__``.

    Only the two private attributes set here are populated: the filesystem
    mode and an empty custom-description mapping.
    """
    cls = SurfSenseFilesystemMiddleware
    instance = cls.__new__(cls)
    instance._custom_tool_descriptions = {}
    instance._filesystem_mode = mode
    return instance
|
||||
|
||||
|
||||
def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc"):
|
||||
state = state or {}
|
||||
state.setdefault("cwd", "/documents")
|
||||
return SimpleNamespace(state=state, tool_call_id=tool_call_id)
|
||||
|
||||
|
||||
class _KBBackendStub(KBPostgresBackend):
    """Test double that can be constructed without a runtime or DB session.

    The real ``__init__`` is deliberately not called; only the two async
    methods the rm/rmdir tools touch are injected as ``AsyncMock`` objects.
    Subclassing keeps ``isinstance(backend, KBPostgresBackend)`` checks
    inside the tools happy, which is what gates them from the desktop code
    path.
    """

    def __init__(self, *, children=None, file_data=None) -> None:
        # Directory listing returned by ``als_info``; falsy input means "empty".
        listing = children if children else []
        # ``_load_file_data`` yields ``(data, 17)`` for a file, ``None`` otherwise.
        loaded = None if file_data is None else (file_data, 17)
        self.als_info = AsyncMock(return_value=listing)
        self._load_file_data = AsyncMock(return_value=loaded)
|
||||
|
||||
|
||||
def _make_backend_stub(*, children=None, file_data=None) -> KBPostgresBackend:
    """Return a ``_KBBackendStub`` configured with the given listing/file data."""
    stub = _KBBackendStub(file_data=file_data, children=children)
    return stub
|
||||
|
||||
|
||||
def _bind_backend(middleware, backend):
|
||||
"""Inject a backend resolver onto the middleware test instance."""
|
||||
middleware._get_backend = lambda runtime: backend
|
||||
return backend
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# rm
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRmStaging:
    """Staging-payload checks for the cloud-mode ``rm`` tool."""

    @pytest.mark.asyncio
    async def test_stages_delete_and_tombstones_state(self):
        target = "/documents/notes.md"
        mw = _make_middleware()
        _bind_backend(mw, _make_backend_stub(children=[], file_data={"content": ["x"]}))
        initial_state = {
            "cwd": "/documents",
            "files": {target: {"content": ["hello"]}},
            "doc_id_by_path": {target: 17},
        }
        runtime = _runtime(initial_state, tool_call_id="tc-1")

        rm = mw._create_rm_tool()
        outcome = await rm.coroutine(target, runtime=runtime)

        assert hasattr(outcome, "update"), f"expected Command, got {outcome!r}"
        staged = outcome.update
        assert staged["pending_deletes"] == [{"path": target, "tool_call_id": "tc-1"}]
        assert staged["files"] == {target: None}
        assert staged["doc_id_by_path"] == {target: None}

    @pytest.mark.asyncio
    async def test_rejects_documents_root(self):
        mw = _make_middleware()
        runtime = _runtime()
        rm = mw._create_rm_tool()
        message = await rm.coroutine("/documents", runtime=runtime)
        assert isinstance(message, str)
        assert "refusing to rm" in message

    @pytest.mark.asyncio
    async def test_rejects_root(self):
        mw = _make_middleware()
        runtime = _runtime()
        rm = mw._create_rm_tool()
        message = await rm.coroutine("/", runtime=runtime)
        assert isinstance(message, str)
        assert "refusing to rm" in message

    @pytest.mark.asyncio
    async def test_rejects_directory_via_staged_dirs(self):
        staged_dir = "/documents/team-x"
        mw = _make_middleware()
        runtime = _runtime({"staged_dirs": [staged_dir]})
        rm = mw._create_rm_tool()
        message = await rm.coroutine(staged_dir, runtime=runtime)
        assert isinstance(message, str)
        assert "directory" in message.lower()
        assert "rmdir" in message

    @pytest.mark.asyncio
    async def test_rejects_directory_via_listing(self):
        mw = _make_middleware()
        # The backend listing reports a child, so the target looks like a directory.
        listing = [{"path": "/documents/foo/x.md", "is_dir": False}]
        _bind_backend(mw, _make_backend_stub(children=listing))
        runtime = _runtime()
        rm = mw._create_rm_tool()
        message = await rm.coroutine("/documents/foo", runtime=runtime)
        assert isinstance(message, str)
        assert "directory" in message.lower()

    @pytest.mark.asyncio
    async def test_rejects_anonymous_doc(self):
        anon_path = "/documents/uploaded.xml"
        anon_doc = {
            "path": anon_path,
            "title": "uploaded",
            "content": "",
            "chunks": [],
        }
        mw = _make_middleware()
        runtime = _runtime({"kb_anon_doc": anon_doc})
        rm = mw._create_rm_tool()
        message = await rm.coroutine(anon_path, runtime=runtime)
        assert isinstance(message, str)
        assert "read-only" in message

    @pytest.mark.asyncio
    async def test_drops_path_from_dirty_paths(self):
        target = "/documents/notes.md"
        mw = _make_middleware()
        _bind_backend(mw, _make_backend_stub(children=[], file_data={"content": ["x"]}))
        initial_state = {
            "files": {target: {"content": ["x"]}},
            "doc_id_by_path": {target: 17},
            "dirty_paths": [target],
        }
        runtime = _runtime(initial_state)
        rm = mw._create_rm_tool()
        outcome = await rm.coroutine(target, runtime=runtime)
        staged = outcome.update
        # Index 0 is the _CLEAR sentinel; everything after it must NOT
        # contain the rm'd path.
        remaining = (staged.get("dirty_paths") or [])[1:]
        assert target not in remaining
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# rmdir
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRmdirStaging:
    """Staging-payload checks for the cloud-mode ``rmdir`` tool."""

    @pytest.mark.asyncio
    async def test_stages_dir_delete_when_empty_and_db_backed(self):
        target = "/documents/proj"
        mw = _make_middleware()
        backend = _bind_backend(mw, _make_backend_stub(children=[]))
        # The target is a folder, not a file: no file data is loadable.
        backend._load_file_data = AsyncMock(return_value=None)
        # First listing: children of the target (none). Second listing: the
        # parent, which claims the folder exists.
        own_children: list = []
        parent_listing = [{"path": target, "is_dir": True}]
        backend.als_info = AsyncMock(side_effect=[own_children, parent_listing])
        runtime = _runtime({"cwd": "/documents"}, tool_call_id="tc-rd")

        rmdir = mw._create_rmdir_tool()
        outcome = await rmdir.coroutine(target, runtime=runtime)

        assert hasattr(outcome, "update")
        staged = outcome.update
        assert staged["pending_dir_deletes"] == [
            {"path": target, "tool_call_id": "tc-rd"}
        ]

    @pytest.mark.asyncio
    async def test_rejects_non_empty(self):
        mw = _make_middleware()
        listing = [{"path": "/documents/proj/x.md", "is_dir": False}]
        _bind_backend(mw, _make_backend_stub(children=listing))
        runtime = _runtime()
        rmdir = mw._create_rmdir_tool()
        message = await rmdir.coroutine("/documents/proj", runtime=runtime)
        assert isinstance(message, str)
        assert "not empty" in message

    @pytest.mark.asyncio
    async def test_unstages_same_turn_mkdir(self):
        scratch = "/documents/scratch"
        mw = _make_middleware()
        _bind_backend(mw, _make_backend_stub(children=[]))
        runtime = _runtime(
            {"cwd": "/documents", "staged_dirs": [scratch]},
            tool_call_id="tc-rd",
        )
        rmdir = mw._create_rmdir_tool()
        outcome = await rmdir.coroutine(scratch, runtime=runtime)

        assert hasattr(outcome, "update")
        staged = outcome.update
        assert "pending_dir_deletes" not in staged
        # _CLEAR sentinel first, then the remaining items (none in this case).
        sentinel, *remaining = staged["staged_dirs"]
        assert sentinel == "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00"
        assert scratch not in remaining

    @pytest.mark.asyncio
    async def test_rejects_root(self):
        mw = _make_middleware()
        runtime = _runtime()
        rmdir = mw._create_rmdir_tool()
        protected = ("/", "/documents")
        for victim in protected:
            message = await rmdir.coroutine(victim, runtime=runtime)
            assert isinstance(message, str)
            assert "refusing to rmdir" in message

    @pytest.mark.asyncio
    async def test_rejects_cwd(self):
        cwd = "/documents/proj"
        mw = _make_middleware()
        runtime = _runtime({"cwd": cwd})
        rmdir = mw._create_rmdir_tool()
        message = await rmdir.coroutine(cwd, runtime=runtime)
        assert isinstance(message, str)
        assert "cwd" in message.lower()

    @pytest.mark.asyncio
    async def test_rejects_ancestor_of_cwd(self):
        mw = _make_middleware()
        runtime = _runtime({"cwd": "/documents/proj/sub"})
        rmdir = mw._create_rmdir_tool()
        message = await rmdir.coroutine("/documents/proj", runtime=runtime)
        assert isinstance(message, str)
        assert "cwd" in message.lower()

    @pytest.mark.asyncio
    async def test_rejects_files(self):
        mw = _make_middleware()
        _bind_backend(mw, _make_backend_stub(children=[], file_data={"content": ["x"]}))
        runtime = _runtime()
        rmdir = mw._create_rmdir_tool()
        message = await rmdir.coroutine("/documents/notes.md", runtime=runtime)
        assert isinstance(message, str)
        assert "is a file" in message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KBPostgresBackend view filter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKBPostgresBackendDeleteFilter:
    """als_info / glob / grep should suppress paths queued for delete."""

    def _make_backend(self, state: dict[str, Any]) -> KBPostgresBackend:
        """Build a real backend whose runtime exposes only ``state``."""
        return KBPostgresBackend(
            search_space_id=1,
            runtime=SimpleNamespace(state=state),
        )

    def test_pending_filesystem_view_returns_deleted_paths(self):
        staged_state = {
            "pending_deletes": [
                {"path": "/documents/x.md", "tool_call_id": "t1"},
            ],
            "pending_dir_deletes": [
                {"path": "/documents/d1", "tool_call_id": "t2"},
            ],
        }
        backend = self._make_backend(staged_state)
        removed, alias, deleted_dirs = backend._pending_filesystem_view({})
        assert "/documents/x.md" in removed
        assert "/documents/d1" in deleted_dirs
        assert alias == {}

    def test_dir_suppressed_covers_descendants(self):
        backend = self._make_backend({})
        deleted_dirs = {"/documents/d"}
        # The directory itself and anything nested below it are suppressed.
        for covered in (
            "/documents/d",
            "/documents/d/x.md",
            "/documents/d/sub/y.md",
        ):
            assert backend._is_dir_suppressed(covered, deleted_dirs)
        assert not backend._is_dir_suppressed("/documents/other.md", deleted_dirs)
|
||||
|
|
@ -98,10 +98,54 @@ class TestInitialFilesystemState:
|
|||
state = _initial_filesystem_state()
|
||||
assert state["cwd"] == "/documents"
|
||||
assert state["staged_dirs"] == []
|
||||
assert state["staged_dir_tool_calls"] == {}
|
||||
assert state["pending_moves"] == []
|
||||
assert state["pending_deletes"] == []
|
||||
assert state["pending_dir_deletes"] == []
|
||||
assert state["doc_id_by_path"] == {}
|
||||
assert state["dirty_paths"] == []
|
||||
assert state["dirty_path_tool_calls"] == {}
|
||||
assert state["kb_priority"] == []
|
||||
assert state["kb_matched_chunk_ids"] == {}
|
||||
assert state["kb_anon_doc"] is None
|
||||
assert state["tree_version"] == 0
|
||||
|
||||
|
||||
class TestMultiEditSamePathCoalescing:
    """Several edits to one path in a turn must yield ONE binding record.

    The persistence body looks up ``dirty_path_tool_calls[path]`` to find
    the tool_call_id that produced the current on-disk state.
    ``dirty_paths`` dedupes through :func:`_add_unique_reducer`, so a second
    edit adds no new path entry; ``_dict_merge_with_tombstones_reducer``
    lets the right-hand side overwrite, so the LATEST tool_call_id wins.
    That is the desired snapshot behavior: revert restores the pre-mutation
    state, and back-to-back edits in one turn collapse into a single
    reversible op (ONE Revert button per turn-per-path, not N).
    """

    def test_dirty_paths_dedupes_repeated_writes(self):
        # Writing the same path twice must leave exactly one entry.
        path = "/documents/a.md"
        after_first = _add_unique_reducer([], [path])
        after_second = _add_unique_reducer(after_first, [path])
        assert after_second == ["/documents/a.md"]

    def test_dirty_path_tool_calls_keeps_latest_tool_call_id(self):
        path = "/documents/a.md"
        # The first write binds the path to tcid-1 ...
        bindings = _dict_merge_with_tombstones_reducer({}, {path: "tcid-1"})
        # ... and the second write rebinds it to tcid-2 (latest wins).
        bindings = _dict_merge_with_tombstones_reducer(bindings, {path: "tcid-2"})
        assert bindings == {"/documents/a.md": "tcid-2"}

    def test_rm_tombstones_dirty_path_tool_call(self):
        # ``rm`` writes ``{path: None}`` so a stale binding cannot leak
        # past the delete.
        path = "/documents/a.md"
        before = {path: "tcid-1"}
        tombstone = {path: None}
        assert _dict_merge_with_tombstones_reducer(before, tombstone) == {}
|
||||
|
|
|
|||
0
surfsense_backend/tests/unit/db/__init__.py
Normal file
0
surfsense_backend/tests/unit/db/__init__.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
"""Smoke test for the ``134_relax_revision_fks`` Alembic migration.
|
||||
|
||||
A full apply/rollback test would require a live Postgres; here we verify
|
||||
the migration module's static contract:
|
||||
|
||||
* The chain wires it as a successor of ``133_drop_documents_content_hash_unique``.
|
||||
* ``upgrade()`` declares two FK creations with ``ondelete='SET NULL'``
|
||||
(one for ``document_revisions.document_id``, one for
|
||||
``folder_revisions.folder_id``).
|
||||
* ``downgrade()`` re-establishes ``ondelete='CASCADE'`` after draining
|
||||
orphaned revisions.
|
||||
|
||||
If any of these invariants regress the snapshot/revert pipeline silently
|
||||
loses the ability to undo ``rm`` / ``rmdir`` on environments that ran the
|
||||
migration "down" or never ran it at all.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_MIGRATION_PATH = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "alembic"
|
||||
/ "versions"
|
||||
/ "134_relax_revision_fks.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_migration():
    """Execute the migration file at ``_MIGRATION_PATH`` and return the module.

    The module is loaded by file path under the name ``_migration_134``, so
    no package import is needed.
    """
    spec = importlib.util.spec_from_file_location("_migration_134", _MIGRATION_PATH)
    assert spec and spec.loader, "could not load migration spec"
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded
|
||||
|
||||
|
||||
def test_migration_chain_revision_ids() -> None:
    """Revision ``134`` must chain directly onto revision ``133``.

    The migration uses short numeric revision IDs, matching the in-tree
    convention; the ``134_<slug>.py`` filename is documentation, not the
    canonical revision string.
    """
    migration = _load_migration()
    expected_ids = (("revision", "134"), ("down_revision", "133"))
    for attr, expected in expected_ids:
        assert getattr(migration, attr, None) == expected
|
||||
|
||||
|
||||
def test_migration_exposes_upgrade_and_downgrade() -> None:
    """Both Alembic entry points must exist and be callable."""
    migration = _load_migration()
    assert callable(getattr(migration, "upgrade", None)), "upgrade() is required"
    assert callable(getattr(migration, "downgrade", None)), "downgrade() is required"
|
||||
|
||||
|
||||
def test_upgrade_creates_set_null_fks_for_both_revision_tables() -> None:
    """``upgrade()`` source must relax both revision FKs to SET NULL."""
    upgrade_src = inspect.getsource(_load_migration().upgrade)
    for table in ("document_revisions", "folder_revisions"):
        assert table in upgrade_src
    # Both new FKs MUST be ON DELETE SET NULL — the whole point of the
    # migration is that snapshots outlive their parent row.
    set_null_count = upgrade_src.count('ondelete="SET NULL"')
    assert set_null_count >= 2
    # The ``document_id`` / ``folder_id`` columns become nullable as well.
    assert "nullable=True" in upgrade_src
|
||||
|
||||
|
||||
def test_downgrade_drains_orphans_then_restores_cascade() -> None:
    """``downgrade()`` source must delete orphans and restore CASCADE/NOT NULL."""
    downgrade_src = inspect.getsource(_load_migration().downgrade)
    # Orphaned rows have to be drained BEFORE NOT NULL can be re-imposed.
    drain_statements = (
        "DELETE FROM document_revisions WHERE document_id IS NULL",
        "DELETE FROM folder_revisions WHERE folder_id IS NULL",
    )
    for statement in drain_statements:
        assert statement in downgrade_src
    # Afterwards the original CASCADE/NOT NULL contract comes back.
    cascade_count = downgrade_src.count('ondelete="CASCADE"')
    assert cascade_count >= 2
    assert "nullable=False" in downgrade_src
|
||||
|
|
@ -168,6 +168,8 @@ class TestModeSpecificPrompts:
|
|||
"edit_file",
|
||||
"move_file",
|
||||
"mkdir",
|
||||
"rm",
|
||||
"rmdir",
|
||||
"list_tree",
|
||||
"grep",
|
||||
):
|
||||
|
|
@ -182,6 +184,8 @@ class TestModeSpecificPrompts:
|
|||
"edit_file",
|
||||
"move_file",
|
||||
"mkdir",
|
||||
"rm",
|
||||
"rmdir",
|
||||
"list_tree",
|
||||
"grep",
|
||||
):
|
||||
|
|
@ -190,6 +194,18 @@ class TestModeSpecificPrompts:
|
|||
assert "/documents/" not in text, f"{name} mentions cloud namespace"
|
||||
assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"
|
||||
|
||||
def test_cloud_descs_include_rm_and_rmdir(self):
    """Cloud-mode tool descriptions must cover ``rm`` and ``rmdir``."""
    cloud = _build_tool_descriptions(FilesystemMode.CLOUD)
    assert "rm" in cloud and "rmdir" in cloud
    rm_text = cloud["rm"]
    assert "Deletes a single file" in rm_text
    rmdir_text = cloud["rmdir"]
    assert "Deletes an empty directory" in rmdir_text
    assert "rmdir" in rmdir_text and "POSIX" in rmdir_text
|
||||
|
||||
def test_desktop_descs_warn_about_irreversibility(self):
    """Desktop-mode ``rm``/``rmdir`` descriptions must say they are NOT reversible."""
    desktop = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER)
    for tool_name in ("rm", "rmdir"):
        assert "NOT reversible" in desktop[tool_name]
|
||||
|
||||
def test_sandbox_addendum_appended_when_available(self):
|
||||
prompt = _build_filesystem_system_prompt(
|
||||
FilesystemMode.CLOUD, sandbox_available=True
|
||||
|
|
|
|||
|
|
@ -0,0 +1,309 @@
|
|||
"""Unit tests for the kb_persistence snapshot helpers.
|
||||
|
||||
The full ``commit_staged_filesystem_state`` body exercises a real session
|
||||
in integration tests; here we verify the building blocks used by the
|
||||
snapshot/revert pipeline:
|
||||
|
||||
* ``_find_action_ids_batch`` issues a SINGLE query for N tool_call_ids
|
||||
(regression guard against the N+1 lookup pattern).
|
||||
* ``_mark_action_reversible`` is a no-op when ``action_id`` is ``None``.
|
||||
* ``_doc_revision_payload`` and ``_load_chunks_for_snapshot`` produce the
|
||||
shape the snapshot helpers consume.
|
||||
|
||||
These tests use ``MagicMock`` / ``AsyncMock`` against a fake session so
|
||||
the assertions run in milliseconds and don't require Postgres.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.middleware import kb_persistence
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeResult:
|
||||
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
|
||||
self._rows = rows or []
|
||||
self._scalar = scalar
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
return list(self._rows)
|
||||
|
||||
def scalar_one_or_none(self) -> Any:
|
||||
return self._scalar
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.execute = AsyncMock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_action_ids_batch_issues_single_query() -> None:
    """The lookup MUST be a single ``IN (...)`` SELECT, not N selects."""
    session = _FakeSession()
    canned_rows = [
        MagicMock(id=11, tool_call_id="tc-a"),
        MagicMock(id=22, tool_call_id="tc-b"),
        MagicMock(id=33, tool_call_id="tc-c"),
    ]
    session.execute.return_value = _FakeResult(rows=canned_rows)

    requested = {"tc-a", "tc-b", "tc-c"}
    mapping = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=1,
        tool_call_ids=requested,
    )

    assert mapping == {"tc-a": 11, "tc-b": 22, "tc-c": 33}
    assert session.execute.await_count == 1, (
        "Snapshot binding must batch into ONE query; got "
        f"{session.execute.await_count} (regression: N+1 lookup pattern)."
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_thread_id_missing() -> None:
    """Without a thread id the helper returns ``{}`` and never queries."""
    session = _FakeSession()
    lookup = kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=None,
        tool_call_ids={"tc-a"},
    )
    mapping = await lookup
    assert mapping == {}
    assert session.execute.await_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_no_calls() -> None:
    """An empty tool_call_id set returns ``{}`` and never queries."""
    session = _FakeSession()
    no_calls: set = set()
    lookup = kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=42,
        tool_call_ids=no_calls,
    )
    mapping = await lookup
    assert mapping == {}
    assert session.execute.await_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mark_action_reversible_is_noop_for_null_id() -> None:
    """A ``None`` action id must not trigger any statement."""
    session = _FakeSession()
    missing_id = None
    await kb_persistence._mark_action_reversible(session, action_id=missing_id)  # type: ignore[arg-type]
    executed = session.execute.await_count
    assert executed == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mark_action_reversible_runs_update_for_real_id() -> None:
    """A real action id results in exactly one executed statement."""
    session = _FakeSession()
    real_id = 99
    await kb_persistence._mark_action_reversible(session, action_id=real_id)  # type: ignore[arg-type]
    executed = session.execute.await_count
    assert executed == 1
|
||||
|
||||
|
||||
def test_doc_revision_payload_captures_metadata_virtual_path() -> None:
    """Snapshot helpers must capture ``metadata_before`` for revert reuse."""
    metadata = {"virtual_path": "/documents/team/notes.md"}
    chunks = [{"content": "x"}]
    doc = MagicMock()
    doc.title = "notes.md"
    doc.folder_id = 7
    doc.content = "body"
    doc.document_metadata = metadata

    payload = kb_persistence._doc_revision_payload(doc, chunks_before=chunks)

    assert payload["title_before"] == "notes.md"
    assert payload["folder_id_before"] == 7
    assert payload["content_before"] == "body"
    assert payload["chunks_before"] == [{"content": "x"}]
    assert payload["metadata_before"] == {"virtual_path": "/documents/team/notes.md"}
|
||||
|
||||
|
||||
def test_doc_revision_payload_handles_missing_metadata() -> None:
    """A document without metadata yields ``metadata_before is None``."""
    doc = MagicMock()
    doc.title = ""
    doc.content = ""
    doc.folder_id = None
    doc.document_metadata = None
    snapshot = kb_persistence._doc_revision_payload(doc)
    assert snapshot["metadata_before"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_chunks_for_snapshot_returns_content_only() -> None:
    """Snapshot chunks intentionally omit embeddings (regenerated on revert)."""
    session = _FakeSession()
    chunk_rows = [MagicMock(content="alpha"), MagicMock(content="beta")]
    session.execute.return_value = _FakeResult(rows=chunk_rows)
    snapshot_chunks = await kb_persistence._load_chunks_for_snapshot(
        session,
        doc_id=42,  # type: ignore[arg-type]
    )
    assert snapshot_chunks == [{"content": "alpha"}, {"content": "beta"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deferred reversibility-flip dispatches.
|
||||
#
|
||||
# The snapshot helpers used to dispatch ``action_log_updated`` directly
|
||||
# from inside the SAVEPOINT block. That meant the SSE side-channel
|
||||
# could tell the UI a row was reversible while the OUTER transaction
|
||||
# was still pending — and if the outer commit failed, every SAVEPOINT
|
||||
# rolled back too, leaving the UI in a state inconsistent with
|
||||
# durable storage. The deferred-dispatch contract fixes that:
|
||||
#
|
||||
# • when a ``deferred_dispatches`` list is provided, the helper
|
||||
# APPENDS the action_id and does NOT dispatch;
|
||||
# • the caller (``commit_staged_filesystem_state``) flushes the list
|
||||
# only AFTER ``await session.commit()`` succeeds; on rollback it
|
||||
# clears the list so nothing is emitted.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _NestedCtx:
|
||||
"""Async context manager mimicking ``session.begin_nested()``."""
|
||||
|
||||
async def __aenter__(self) -> _NestedCtx:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pre_write_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Helpers MUST queue dispatches when ``deferred_dispatches`` is set."""
    inline_dispatches: list[int] = []
    queued: list[int] = []

    async def _record_dispatch(action_id: int | None) -> None:
        if action_id is None:
            return
        inline_dispatches.append(int(action_id))

    def _assign_revision_id(rev: Any) -> None:
        # Stand in for the primary key the database would assign.
        rev.id = 17

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _record_dispatch
    )

    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
    session.flush = AsyncMock()
    session.add = MagicMock(side_effect=_assign_revision_id)

    doc = MagicMock(id=99, document_metadata={"virtual_path": "/documents/x.md"})
    doc.title = "x.md"
    doc.folder_id = None
    doc.content = "body"

    revision_id = await kb_persistence._snapshot_document_pre_write(
        session,  # type: ignore[arg-type]
        doc=doc,
        action_id=42,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=queued,
    )

    assert revision_id == 17
    # No inline dispatch may have fired; the action_id is queued instead.
    assert inline_dispatches == []
    assert queued == [42]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pre_write_snapshot_dispatches_inline_when_list_omitted(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Direct callers (no outer transaction) keep the legacy inline dispatch."""
    inline_dispatches: list[int] = []

    async def _record_dispatch(action_id: int | None) -> None:
        if action_id is None:
            return
        inline_dispatches.append(int(action_id))

    def _assign_revision_id(rev: Any) -> None:
        # Stand in for the primary key the database would assign.
        rev.id = 7

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _record_dispatch
    )

    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
    session.flush = AsyncMock()
    session.add = MagicMock(side_effect=_assign_revision_id)

    doc = MagicMock(id=11, document_metadata={"virtual_path": "/documents/y.md"})
    doc.title = "y.md"
    doc.folder_id = None
    doc.content = "body"

    # ``deferred_dispatches`` is deliberately omitted so the helper falls
    # back to dispatching inline.
    await kb_persistence._snapshot_document_pre_write(
        session,  # type: ignore[arg-type]
        doc=doc,
        action_id=88,
        search_space_id=1,
        turn_id="t-1",
    )

    assert inline_dispatches == [88]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pre_mkdir_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Folder mkdir snapshots honour the same deferred-dispatch contract.

    When a ``deferred_dispatches`` list is passed, the helper must append
    the action id to it and must NOT call the inline dispatcher.
    """
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock()  # _mark_action_reversible calls execute
    session.flush = AsyncMock()

    def _add(rev: Any) -> None:
        # Stand in for the primary key the database would assign.
        rev.id = 3

    session.add = MagicMock(side_effect=_add)

    dispatched: list[int] = []

    async def _fake_dispatch(action_id: int | None) -> None:
        if action_id is not None:
            dispatched.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
    )

    deferred: list[int] = []
    # ``name`` is reserved by the Mock constructor (it names the mock for its
    # repr) and does NOT create a ``.name`` attribute, so the folder name has
    # to be assigned after construction to be a real string.
    folder = MagicMock(id=2, parent_id=None, position="a0")
    folder.name = "f"

    await kb_persistence._snapshot_folder_pre_mkdir(
        session,  # type: ignore[arg-type]
        folder=folder,
        action_id=55,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=deferred,
    )

    assert dispatched == []
    assert deferred == [55]
|
||||
139
surfsense_backend/tests/unit/middleware/test_knowledge_tree.py
Normal file
139
surfsense_backend/tests/unit/middleware/test_knowledge_tree.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
"""Unit tests for ``KnowledgeTreeMiddleware`` rendering.
|
||||
|
||||
The empty-folder marker is critical UX: without it, the LLM cannot
|
||||
distinguish a leaf folder containing one document from a leaf folder
|
||||
that has no descendants at all, and ends up firing ``rmdir`` on
|
||||
non-empty folders. These tests pin the rendering contract so that
|
||||
contract cannot silently regress.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.middleware.knowledge_tree import KnowledgeTreeMiddleware
|
||||
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT
|
||||
|
||||
|
||||
def _compute(folder_paths: list[str], doc_paths: list[str]) -> set[str]:
    """Shorthand for ``KnowledgeTreeMiddleware._compute_non_empty_folders``."""
    compute_non_empty = KnowledgeTreeMiddleware._compute_non_empty_folders
    return compute_non_empty(folder_paths, doc_paths)
|
||||
|
||||
|
||||
class TestComputeNonEmptyFolders:
    """Pins which folder paths ``_compute`` reports as non-empty."""

    def test_folder_with_direct_document_is_non_empty(self):
        boarding_pass = f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"
        result = _compute([boarding_pass], [f"{boarding_pass}/southwest.pdf.xml"])
        assert boarding_pass in result

    def test_truly_empty_leaf_folder_is_not_non_empty(self):
        lonely_leaf = f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"
        assert _compute([lonely_leaf], []) == set()

    def test_documents_propagate_up_to_all_ancestors(self):
        top = f"{DOCUMENTS_ROOT}/A"
        middle = f"{top}/B"
        bottom = f"{middle}/C"
        # A single document at the deepest level marks the whole chain.
        result = _compute([top, middle, bottom], [f"{bottom}/file.xml"])
        assert result == {top, middle, bottom}

    def test_chain_with_subfolders_marks_only_leaf_empty(self):
        # POSIX-like rule: a folder counts as "empty" only when it has no
        # immediate children, whether documents or sub-folders. The model
        # relies on this because parallel ``rmdir`` calls all observe the
        # same starting state, so removing a parent before its children is
        # never safe.
        outer = f"{DOCUMENTS_ROOT}/X"
        inner = f"{outer}/Y"
        leaf = f"{inner}/Z"
        result = _compute([outer, inner, leaf], [])
        # Only the leaf ``X/Y/Z`` is empty; ``X`` and ``X/Y`` each hold a
        # sub-folder, so neither may carry the ``(empty)`` marker.
        assert result == {outer, inner}

    def test_sibling_with_doc_does_not_mark_other_sibling_non_empty(self):
        # Same shape as a real DB layout, where every intermediate folder
        # has its own row in the ``folders`` table.
        travel = f"{DOCUMENTS_ROOT}/Travel"
        boarding_pass = f"{travel}/Boarding Pass"
        notes = f"{travel}/Notes"
        result = _compute(
            [travel, boarding_pass, notes],
            [f"{notes}/itinerary.xml"],
        )
        # ``Travel`` has children and ``Notes`` has the document; the
        # sibling leaf ``Boarding Pass`` stays empty.
        assert travel in result
        assert notes in result
        assert boarding_pass not in result
|
||||
|
||||
|
||||
class TestFormatTreeRendering:
    """Integration check: empty leaf gets ``(empty)`` marker; non-empty doesn't."""

    def _render(
        self,
        folder_paths: list[str],
        doc_specs: list[dict],
    ) -> str:
        """Render the tree for the given folders and document specs.

        Folder ids are assigned 1..N in list order; each spec dict becomes
        an attribute-bag object passed to ``_format_tree``.
        """
        from app.agents.new_chat.path_resolver import PathIndex

        class _Row:
            def __init__(self, **attrs):
                for key, value in attrs.items():
                    setattr(self, key, value)

        path_index = PathIndex(
            folder_paths=dict(enumerate(folder_paths, start=1)),
        )
        doc_rows = [_Row(**spec) for spec in doc_specs]
        middleware = KnowledgeTreeMiddleware(
            search_space_id=1,
            filesystem_mode=None,  # type: ignore[arg-type]
        )
        return middleware._format_tree(path_index, doc_rows)

    def test_renders_empty_marker_only_for_truly_empty_folders(self):
        # Scenario from the bug report: ``Boarding Pass`` has just lost its
        # only document, while ``Tax Returns`` still holds ``federal.pdf``.
        # Every intermediate folder is listed, as in the real DB layout.
        tax_returns = "/documents/File Upload/2026-04-15/Finance/Tax Returns"
        folder_paths = [
            "/documents/File Upload",
            "/documents/File Upload/2026-04-08",
            "/documents/File Upload/2026-04-08/Travel",
            "/documents/File Upload/2026-04-08/Travel/Boarding Pass",
            "/documents/File Upload/2026-04-15",
            "/documents/File Upload/2026-04-15/Finance",
            tax_returns,
        ]
        federal_pdf = {
            "id": 100,
            "title": "federal.pdf",
            # ``_render`` numbers folders from 1 in list order.
            "folder_id": folder_paths.index(tax_returns) + 1,
        }
        rendered = self._render(folder_paths=folder_paths, doc_specs=[federal_pdf])

        assert "Boarding Pass/ (empty)" in rendered
        assert "Tax Returns/ (empty)" not in rendered
        # Ancestors of the surviving document must not be flagged either.
        assert "Finance/ (empty)" not in rendered
        assert "2026-04-15/ (empty)" not in rendered
|
||||
|
|
@ -69,3 +69,74 @@ def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path):
|
|||
assert write.error is not None
|
||||
assert "parent directory" in write.error
|
||||
assert not (tmp_path / "tempoo").exists()
|
||||
|
||||
|
||||
def test_local_backend_delete_file_success(tmp_path: Path):
    """Deleting an existing file returns its virtual path and removes it from disk."""
    backend = LocalFolderBackend(str(tmp_path))
    victim = tmp_path / "delete-me.md"
    victim.write_text("bye")

    outcome = backend.delete_file("/delete-me.md")

    assert outcome.error is None
    assert outcome.path == "/delete-me.md"
    assert not victim.exists()
|
||||
|
||||
|
||||
def test_local_backend_delete_file_rejects_directory(tmp_path: Path):
    """``delete_file`` on a directory returns an error and leaves it in place."""
    backend = LocalFolderBackend(str(tmp_path))
    directory = tmp_path / "subdir"
    directory.mkdir()

    outcome = backend.delete_file("/subdir")

    assert outcome.error is not None
    assert "directory" in outcome.error
    assert directory.exists()
|
||||
|
||||
|
||||
def test_local_backend_delete_file_missing_returns_error(tmp_path: Path):
    """``delete_file`` on a path that does not exist reports "not found"."""
    backend = LocalFolderBackend(str(tmp_path))

    outcome = backend.delete_file("/nope.md")

    assert outcome.error is not None
    assert "not found" in outcome.error
|
||||
|
||||
|
||||
def test_local_backend_rmdir_success(tmp_path: Path):
    """``rmdir`` on an empty directory returns its virtual path and removes it."""
    backend = LocalFolderBackend(str(tmp_path))
    empty_dir = tmp_path / "empty"
    empty_dir.mkdir()

    outcome = backend.rmdir("/empty")

    assert outcome.error is None
    assert outcome.path == "/empty"
    assert not empty_dir.exists()
|
||||
|
||||
|
||||
def test_local_backend_rmdir_rejects_non_empty(tmp_path: Path):
    """``rmdir`` on a directory with a child errors out and keeps the child."""
    backend = LocalFolderBackend(str(tmp_path))
    parent = tmp_path / "withkid"
    parent.mkdir()
    child = parent / "child.md"
    child.write_text("x")

    outcome = backend.rmdir("/withkid")

    assert outcome.error is not None
    assert "not empty" in outcome.error
    assert child.exists()
|
||||
|
||||
|
||||
def test_local_backend_rmdir_rejects_file(tmp_path: Path):
    """``rmdir`` on a regular file reports "not a directory"."""
    backend = LocalFolderBackend(str(tmp_path))
    plain_file = tmp_path / "f.md"
    plain_file.write_text("x")

    outcome = backend.rmdir("/f.md")

    assert outcome.error is not None
    assert "not a directory" in outcome.error
|
||||
|
||||
|
||||
def test_local_backend_rmdir_rejects_root(tmp_path: Path):
    """``rmdir /`` MUST fail and MUST NOT delete the sandbox root on disk.

    The exact wording of the error comes from ``_resolve_virtual``, so the
    assertion accepts either an "Invalid path" or a "root" message.
    """
    backend = LocalFolderBackend(str(tmp_path))

    outcome = backend.rmdir("/")

    assert outcome.error is not None
    mentions_invalid_path = "Invalid path" in outcome.error
    assert mentions_invalid_path or "root" in outcome.error
    assert tmp_path.exists()
|
||||
|
|
|
|||
0
surfsense_backend/tests/unit/routes/__init__.py
Normal file
0
surfsense_backend/tests/unit/routes/__init__.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
"""Unit tests for the edit-from-arbitrary-position helpers inside ``new_chat_routes``.
|
||||
|
||||
The regenerate route's edit-from-position path introduces:
|
||||
* ``_find_pre_turn_checkpoint_id`` — receives LangGraph checkpoint tuples
  in newest-first order, walks them oldest-first, and returns the id of
  the last checkpoint seen before the first checkpoint whose
  ``metadata["turn_id"]`` equals the edited turn. That checkpoint is the
  rewind target (state immediately before the edited turn started);
  ``None`` is returned when no such checkpoint exists.
|
||||
* ``RegenerateRequest`` accepts ``from_message_id`` + ``revert_actions``
|
||||
with a validator that prevents callers from requesting a revert pass
|
||||
without specifying which turn to roll back.
|
||||
|
||||
These are pure-Python helpers that don't need a live DB, so we exercise
|
||||
them with a small ``CheckpointTuple``-shaped namespace and direct
|
||||
schema instantiation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.routes.new_chat_routes import _find_pre_turn_checkpoint_id
|
||||
from app.schemas.new_chat import RegenerateRequest
|
||||
|
||||
|
||||
def _cp(checkpoint_id: str, turn_id: str | None) -> SimpleNamespace:
|
||||
"""Build a fake ``CheckpointTuple`` with the metadata shape we read."""
|
||||
return SimpleNamespace(
|
||||
config={"configurable": {"checkpoint_id": checkpoint_id}},
|
||||
metadata={"turn_id": turn_id} if turn_id is not None else {},
|
||||
)
|
||||
|
||||
|
||||
class TestFindPreTurnCheckpointId:
    """``_find_pre_turn_checkpoint_id`` over newest-first checkpoint histories."""

    def test_returns_last_pre_turn_checkpoint_when_editing_latest_turn(self) -> None:
        # History is newest-first and T2 is the most recent turn. The
        # latest non-T2 checkpoint (cp2) holds the state immediately before
        # T2 began, so it is the rewind target.
        history = [_cp("cp4", "T2"), _cp("cp3", "T2"), _cp("cp2", "T1"), _cp("cp1", "T1")]
        found = _find_pre_turn_checkpoint_id(history, turn_id="T2")
        assert found == "cp2"

    def test_returns_pre_turn_checkpoint_when_later_turns_exist(self) -> None:
        # Regression: a newest-first walk that returned the FIRST
        # checkpoint with ``turn_id != target`` picked a later-turn
        # checkpoint instead of the pre-turn boundary. Editing T2 must
        # rewind to the latest T1 checkpoint (cp2), not to the latest T3
        # checkpoint (cp6).
        history = [
            _cp("cp6", "T3"),
            _cp("cp5", "T3"),
            _cp("cp4", "T2"),
            _cp("cp3", "T2"),
            _cp("cp2", "T1"),
            _cp("cp1", "T1"),
        ]
        found = _find_pre_turn_checkpoint_id(history, turn_id="T2")
        assert found == "cp2"

    def test_returns_none_when_editing_first_turn(self) -> None:
        # There is no pre-turn boundary; the caller is expected to fall
        # back to the oldest checkpoint or to special-case "first turn of
        # the thread".
        history = [_cp("cp4", "T2"), _cp("cp3", "T2"), _cp("cp2", "T1"), _cp("cp1", "T1")]
        found = _find_pre_turn_checkpoint_id(history, turn_id="T1")
        assert found is None

    def test_returns_none_when_only_edited_turn_present(self) -> None:
        history = [_cp("cp2", "T2"), _cp("cp1", "T2")]
        found = _find_pre_turn_checkpoint_id(history, turn_id="T2")
        assert found is None

    def test_returns_none_for_empty_history(self) -> None:
        found = _find_pre_turn_checkpoint_id([], turn_id="T1")
        assert found is None

    def test_legacy_checkpoints_without_turn_id_count_as_pre_turn(self) -> None:
        # Checkpoints written before migration 136 carry no
        # ``metadata.turn_id``. They came before the edited turn began, so
        # they must remain eligible rewind targets.
        legacy = SimpleNamespace(
            config={"configurable": {"checkpoint_id": "cp2"}},
            metadata=None,
        )
        history = [_cp("cp3", "T2"), legacy, _cp("cp1", "T1")]
        # Oldest-first: cp1 (T1) tracked, cp2 (legacy, no metadata) tracked,
        # then cp3 (T2) crosses the boundary, so cp2 is returned.
        found = _find_pre_turn_checkpoint_id(history, turn_id="T2")
        assert found == "cp2"

    def test_skips_checkpoint_missing_checkpoint_id_in_config(self) -> None:
        # A tuple whose ``config["configurable"]`` lacks ``checkpoint_id``
        # (corrupt / partial) must not crash the search; the last known
        # good target is kept instead.
        broken = SimpleNamespace(
            config={"configurable": {}},
            metadata={"turn_id": "T1"},
        )
        history = [_cp("cp3", "T2"), broken, _cp("cp1", "T1")]
        # cp1 (T1) tracked, broken skipped, cp3 (T2) reached: cp1 wins.
        found = _find_pre_turn_checkpoint_id(history, turn_id="T2")
        assert found == "cp1"
|
||||
|
||||
|
||||
class TestRegenerateRequestValidation:
    """Validation of ``from_message_id`` / ``revert_actions`` on ``RegenerateRequest``."""

    def test_revert_actions_requires_from_message_id(self) -> None:
        payload = {"search_space_id": 1, "user_query": "hi", "revert_actions": True}
        with pytest.raises(Exception) as excinfo:
            RegenerateRequest(**payload)
        # The error text must name the missing field.
        assert "from_message_id" in str(excinfo.value).lower()

    def test_from_message_id_without_revert_is_allowed(self) -> None:
        request = RegenerateRequest(search_space_id=1, user_query="hi", from_message_id=42)
        assert request.from_message_id == 42
        assert request.revert_actions is False

    def test_revert_actions_with_from_message_id_passes(self) -> None:
        request = RegenerateRequest(
            search_space_id=1,
            user_query="hi",
            from_message_id=42,
            revert_actions=True,
        )
        assert request.revert_actions is True
|
||||
530
surfsense_backend/tests/unit/routes/test_revert_turn_route.py
Normal file
530
surfsense_backend/tests/unit/routes/test_revert_turn_route.py
Normal file
|
|
@ -0,0 +1,530 @@
|
|||
"""Unit tests for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
|
||||
|
||||
The per-turn batch revert route walks rows in reverse ``created_at``
|
||||
order, reverts each independently, and returns a per-action result
|
||||
list. Partial success is normal — the response status
|
||||
is ``"partial"`` whenever any row could not be reverted, but we never
|
||||
collapse the whole batch into a 4xx.
|
||||
|
||||
These tests stub ``load_thread`` / ``revert_action`` and feed a fake
|
||||
session, so they exercise the route's dispatch logic without a real DB.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.routes import agent_revert_route
|
||||
from app.services.revert_service import RevertOutcome
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeAction:
|
||||
id: int
|
||||
tool_name: str
|
||||
user_id: str | None = "u1"
|
||||
reverse_of: int | None = None
|
||||
error: dict | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeUser:
|
||||
id: str = "u1"
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ScalarResult:
|
||||
rows: list[Any]
|
||||
|
||||
def first(self) -> Any:
|
||||
return self.rows[0] if self.rows else None
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
return list(self.rows)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Result:
|
||||
rows: list[Any] = field(default_factory=list)
|
||||
|
||||
def scalars(self) -> _ScalarResult:
|
||||
return _ScalarResult(self.rows)
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
# ``_was_already_reverted_batch`` calls ``.all()`` directly on
|
||||
# the row-tuple result (no ``.scalars()`` indirection). The
|
||||
# rows queued for that helper are list[(revert_id, original_id)].
|
||||
return list(self.rows)
|
||||
|
||||
|
||||
class _FakeNestedCtx:
|
||||
"""Async context manager that mimics ``session.begin_nested()``.
|
||||
|
||||
The route raises a sentinel exception inside this block to roll back
|
||||
bad rows. We just pass the exception through.
|
||||
"""
|
||||
|
||||
async def __aenter__(self) -> _FakeNestedCtx:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
# Returning False (or None) propagates the exception; the route
|
||||
# catches its own sentinel above this layer.
|
||||
return False
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Minimal AsyncSession stand-in for the revert-turn route.
|
||||
|
||||
Holds a queue of result objects; each ``execute(...)`` pops the next
|
||||
one. The route calls ``execute`` exactly once per query so this maps
|
||||
cleanly onto the assertion order of the test.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._results: list[_Result] = []
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
# Count execute() calls to assert "no N+1 reverts".
|
||||
self.execute_call_count = 0
|
||||
|
||||
def queue(self, *results: _Result) -> None:
|
||||
self._results.extend(results)
|
||||
|
||||
async def execute(self, _stmt: Any) -> _Result:
|
||||
self.execute_call_count += 1
|
||||
if not self._results:
|
||||
return _Result(rows=[])
|
||||
return self._results.pop(0)
|
||||
|
||||
def begin_nested(self) -> _FakeNestedCtx:
|
||||
return _FakeNestedCtx()
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rolled_back = True
|
||||
|
||||
|
||||
def _enabled_flags() -> AgentFeatureFlags:
    """Flags with the action log and revert route on and ``disable_new_agent_stack`` off."""
    settings = {
        "disable_new_agent_stack": False,
        "enable_action_log": True,
        "enable_revert_route": True,
    }
    return AgentFeatureFlags(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def patch_get_flags():
    """Factory fixture: call it with flags to get a ``patch`` of the route's ``get_flags``."""
    target = "app.routes.agent_revert_route.get_flags"

    def _factory(flags: AgentFeatureFlags):
        return patch(target, return_value=flags)

    return _factory
|
||||
|
||||
|
||||
class TestFlagGuard:
    """The route must refuse to run while ``enable_revert_route`` is off."""

    @pytest.mark.asyncio
    async def test_returns_503_when_revert_route_disabled(
        self, patch_get_flags
    ) -> None:
        disabled = AgentFeatureFlags(
            disable_new_agent_stack=False,
            enable_action_log=True,
            enable_revert_route=False,
        )
        session = _FakeSession()
        with patch_get_flags(disabled), pytest.raises(Exception) as excinfo:
            await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="42:1700000000000",
                session=session,
                user=_FakeUser(),
            )
        # Any exception type is accepted as long as it carries a 503.
        status_code = getattr(excinfo.value, "status_code", None)
        assert status_code == 503
|
||||
|
||||
|
||||
class TestRevertTurnDispatch:
    """Dispatch logic of ``revert_agent_turn`` with every collaborator stubbed.

    Each test queues the results the fake session should return (first the
    rows query, then the batched idempotency probe where one is expected)
    and drives the route through ``_call_route``.
    """

    @staticmethod
    async def _call_route(
        patch_get_flags,
        session: _FakeSession,
        chat_turn_id: str,
        *,
        revert: AsyncMock | None = None,
        user: _FakeUser | None = None,
    ) -> tuple[Any, AsyncMock]:
        """Invoke the route with flags enabled and its collaborators patched.

        Args:
            patch_get_flags: the ``patch_get_flags`` fixture.
            session: fake session pre-loaded with the results to return.
            chat_turn_id: turn identifier forwarded to the route.
            revert: mock installed as ``revert_action``; a bare
                ``AsyncMock`` is used when omitted.
            user: requesting user; defaults to ``_FakeUser()``.

        Returns:
            ``(response, revert_mock)`` so callers can assert on both the
            route response and the calls made to ``revert_action``.
        """
        revert_mock = revert if revert is not None else AsyncMock()
        requester = user if user is not None else _FakeUser()
        with (
            patch_get_flags(_enabled_flags()),
            patch.object(
                agent_revert_route, "load_thread", AsyncMock(return_value=object())
            ),
            patch.object(agent_revert_route, "revert_action", revert_mock),
        ):
            response = await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id=chat_turn_id,
                session=session,
                user=requester,
            )
        return response, revert_mock

    @pytest.mark.asyncio
    async def test_empty_turn_returns_ok_with_no_rows(self, patch_get_flags) -> None:
        session = _FakeSession()
        session.queue(_Result(rows=[]))  # rows query returns nothing

        response, _ = await self._call_route(patch_get_flags, session, "ct-empty")

        assert response.status == "ok"
        assert response.total == 0
        assert response.results == []
        assert session.committed is True

    @pytest.mark.asyncio
    async def test_walks_rows_in_reverse_and_reverts_each(
        self, patch_get_flags
    ) -> None:
        rows = [
            _FakeAction(id=10, tool_name="rm"),
            _FakeAction(id=9, tool_name="write_file"),
            _FakeAction(id=8, tool_name="mkdir"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched ``_was_already_reverted_batch`` probe replaces
        # the previous N per-row SELECTs.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            return RevertOutcome(
                status="ok",
                message=f"reverted-{action.id}",
                new_action_id=100 + action.id,
            )

        response, _ = await self._call_route(
            patch_get_flags,
            session,
            "ct-3",
            revert=AsyncMock(side_effect=_fake_revert),
        )

        assert response.status == "ok"
        assert response.total == 3
        assert response.reverted == 3
        assert [r.action_id for r in response.results] == [10, 9, 8]
        assert all(r.status == "reverted" for r in response.results)
        assert response.results[0].new_action_id == 110
        # Only TWO ``execute`` calls regardless of the row count: one for
        # the rows query, one for the batched
        # ``_was_already_reverted_batch`` probe. Regression guard against
        # re-introducing the per-row N+1 lookup.
        assert session.execute_call_count == 2, (
            "revert-turn loop must batch idempotency probes; got "
            f"{session.execute_call_count} execute() calls (expected 2)."
        )

    @pytest.mark.asyncio
    async def test_already_reverted_rows_are_marked_idempotent(
        self, patch_get_flags
    ) -> None:
        session = _FakeSession()
        session.queue(_Result(rows=[_FakeAction(id=5, tool_name="edit_file")]))
        # Batch probe returns ``[(revert_id, original_id)]``.
        session.queue(_Result(rows=[(42, 5)]))

        response, revert = await self._call_route(patch_get_flags, session, "ct-i")

        assert response.status == "ok"
        assert response.already_reverted == 1
        assert response.results[0].status == "already_reverted"
        assert response.results[0].new_action_id == 42
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_revert_action_skips_existing_revert_rows(
        self, patch_get_flags
    ) -> None:
        rows = [_FakeAction(id=99, tool_name="_revert:edit_file", reverse_of=42)]
        session = _FakeSession()
        session.queue(_Result(rows=rows))

        response, revert = await self._call_route(patch_get_flags, session, "ct-rev")

        assert response.status == "ok"
        assert response.results[0].status == "skipped"
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_partial_success_when_some_rows_not_reversible(
        self, patch_get_flags
    ) -> None:
        rows = [
            _FakeAction(id=2, tool_name="send_email"),
            _FakeAction(id=1, tool_name="edit_file"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched idempotency probe.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.tool_name == "send_email":
                return RevertOutcome(
                    status="not_reversible",
                    message="connector revert not yet implemented",
                )
            return RevertOutcome(status="ok", message="ok", new_action_id=500)

        response, _ = await self._call_route(
            patch_get_flags,
            session,
            "ct-mix",
            revert=AsyncMock(side_effect=_fake_revert),
        )

        assert response.status == "partial"
        assert response.reverted == 1
        assert response.not_reversible == 1
        statuses = sorted(r.status for r in response.results)
        assert statuses == ["not_reversible", "reverted"]

    @pytest.mark.asyncio
    async def test_unexpected_exception_marks_row_failed_not_batch(
        self, patch_get_flags
    ) -> None:
        rows = [
            _FakeAction(id=20, tool_name="edit_file"),
            _FakeAction(id=21, tool_name="edit_file"),
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched idempotency probe.
        session.queue(_Result(rows=[]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.id == 20:
                raise RuntimeError("disk on fire")
            return RevertOutcome(status="ok", message="ok", new_action_id=999)

        response, _ = await self._call_route(
            patch_get_flags,
            session,
            "ct-fail",
            revert=AsyncMock(side_effect=_fake_revert),
        )

        assert response.status == "partial"
        assert response.failed == 1
        assert response.reverted == 1
        bad = next(r for r in response.results if r.action_id == 20)
        assert bad.status == "failed"
        assert "disk on fire" in (bad.error or "")
        good = next(r for r in response.results if r.action_id == 21)
        assert good.status == "reverted"

    @pytest.mark.asyncio
    async def test_permission_denied_when_other_user_owns_action(
        self, patch_get_flags
    ) -> None:
        rows = [_FakeAction(id=7, tool_name="edit_file", user_id="someone-else")]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Batch idempotency probe (no prior reverts).
        session.queue(_Result(rows=[]))

        response, revert = await self._call_route(
            patch_get_flags,
            session,
            "ct-perm",
            user=_FakeUser(id="not-owner"),
        )

        assert response.status == "partial"
        assert response.results[0].status == "permission_denied"
        # ``permission_denied`` has its own dedicated counter so the
        # response invariant ``total == sum(counters)`` always holds
        # without overloading ``not_reversible`` (which historically
        # absorbed this case and confused frontend toasts).
        assert response.permission_denied == 1
        assert response.not_reversible == 0
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_counter_invariant_holds_across_mixed_outcomes(
        self, patch_get_flags
    ) -> None:
        """Every row is accounted for in EXACTLY ONE counter.

        Mixes one of every supported outcome (reverted, already_reverted,
        not_reversible, permission_denied, failed, skipped) and asserts
        that the sum of counters equals ``response.total``.
        """
        rows = [
            _FakeAction(id=10, tool_name="edit_file"),  # ok
            _FakeAction(id=9, tool_name="edit_file"),  # already_reverted
            _FakeAction(id=8, tool_name="send_email"),  # not_reversible
            _FakeAction(id=7, tool_name="rm", user_id="other"),  # permission_denied
            _FakeAction(id=6, tool_name="edit_file"),  # failed
            _FakeAction(id=5, tool_name="_revert:edit_file", reverse_of=99),  # skipped
        ]
        session = _FakeSession()
        session.queue(_Result(rows=rows))
        # Single batched probe; only id=9 has a prior revert.
        # Schema: list[(revert_id, original_id)].
        session.queue(_Result(rows=[(42, 9)]))

        async def _fake_revert(_session, *, action, requester_user_id):
            if action.id == 10:
                return RevertOutcome(status="ok", message="ok", new_action_id=500)
            if action.id == 8:
                return RevertOutcome(
                    status="not_reversible",
                    message="connector revert not yet implemented",
                )
            if action.id == 6:
                raise RuntimeError("boom")
            raise AssertionError(f"unexpected revert call for {action.id}")

        # The default ``_FakeUser()`` is used, so only id=7 has a different
        # ``user_id``.
        response, _ = await self._call_route(
            patch_get_flags,
            session,
            "ct-mixed-all",
            revert=AsyncMock(side_effect=_fake_revert),
        )

        assert response.total == len(rows) == 6
        bucket_sum = (
            response.reverted
            + response.already_reverted
            + response.not_reversible
            + response.permission_denied
            + response.failed
            + response.skipped
        )
        assert bucket_sum == response.total, (
            "Counter invariant broken: total "
            f"({response.total}) != sum of counters ({bucket_sum}). "
            f"Counters: reverted={response.reverted}, "
            f"already_reverted={response.already_reverted}, "
            f"not_reversible={response.not_reversible}, "
            f"permission_denied={response.permission_denied}, "
            f"failed={response.failed}, skipped={response.skipped}"
        )
        assert response.reverted == 1
        assert response.already_reverted == 1
        assert response.not_reversible == 1
        assert response.permission_denied == 1
        assert response.failed == 1
        assert response.skipped == 1

    @pytest.mark.asyncio
    async def test_integrity_error_translates_to_already_reverted(
        self, patch_get_flags
    ) -> None:
        """The partial unique index on ``reverse_of`` raises
        ``IntegrityError`` when a concurrent revert wins the race against
        the pre-flight ``_was_already_reverted`` SELECT. The route MUST
        recover by re-querying for the winning revert id and returning
        ``status="already_reverted"`` (not ``"failed"``) so racing
        clients see consistent idempotent semantics.
        """
        from sqlalchemy.exc import IntegrityError

        session = _FakeSession()
        session.queue(_Result(rows=[_FakeAction(id=33, tool_name="edit_file")]))
        # Batch pre-flight probe: nothing yet (we'll race).
        session.queue(_Result(rows=[]))
        # Post-IntegrityError fallback uses the SCALAR
        # ``_was_already_reverted`` (single-id lookup) so it pulls
        # ``[777]`` via ``.scalars().first()``.
        session.queue(_Result(rows=[777]))

        async def _racing_revert(_session, *, action, requester_user_id):
            raise IntegrityError("INSERT", {}, Exception("dup reverse_of"))

        response, _ = await self._call_route(
            patch_get_flags,
            session,
            "ct-race",
            revert=AsyncMock(side_effect=_racing_revert),
        )

        assert response.failed == 0, (
            "IntegrityError must NOT surface as a failed row; the unique "
            "index is the durable expression of idempotency."
        )
        assert response.already_reverted == 1
        assert response.results[0].status == "already_reverted"
        assert response.results[0].new_action_id == 777
|
||||
|
|
@ -0,0 +1,370 @@
|
|||
"""Unit tests for the filesystem-tool branches of ``revert_service``.
|
||||
|
||||
Covers:
|
||||
|
||||
* Exact-name dispatch — ``rmdir`` does NOT mis-route to the document
|
||||
branch (``"rmdir".startswith("rm")`` would mis-route under the legacy
|
||||
prefix-based dispatch).
|
||||
* ``rm`` revert re-INSERTs a fresh document from the snapshot, including
|
||||
re-creating chunks. Falls back to ``(folder_id_before, title_before)``
|
||||
when ``metadata_before["virtual_path"]`` is missing.
|
||||
* ``write_file`` create-revert (``content_before IS NULL``) DELETEs the
|
||||
document.
|
||||
* ``rmdir`` revert re-INSERTs a fresh folder from the snapshot.
|
||||
* ``mkdir`` revert DELETEs the empty folder; reports ``tool_unavailable``
|
||||
when the folder gained children.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from app.services import revert_service
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _stub_embeddings(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace ``revert_service.embed_texts`` with a zero-vector stub.

    Applied automatically to every test in this module. The stub returns
    one 8-element float32 zero vector per input text.
    """

    def _zero_vectors(texts):
        return [np.zeros(8, dtype=np.float32) for _ in texts]

    monkeypatch.setattr(revert_service, "embed_texts", _zero_vectors)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeResult:
|
||||
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
|
||||
self._rows = rows or []
|
||||
self._scalar = scalar
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
return list(self._rows)
|
||||
|
||||
def scalar_one_or_none(self) -> Any:
|
||||
return self._scalar
|
||||
|
||||
def scalars(self) -> Any:
|
||||
return _FakeScalarsProxy(self._rows)
|
||||
|
||||
|
||||
class _FakeScalarsProxy:
|
||||
def __init__(self, rows: list[Any]) -> None:
|
||||
self._rows = rows
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._rows[0] if self._rows else None
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.execute = AsyncMock()
|
||||
self.added: list[Any] = []
|
||||
self.deleted: list[Any] = []
|
||||
self.flush = AsyncMock()
|
||||
# session.get(Model, pk) lookup
|
||||
self.get = AsyncMock(return_value=None)
|
||||
|
||||
async def _flush_assigning_ids() -> None:
|
||||
for obj in self.added:
|
||||
if getattr(obj, "id", None) is None:
|
||||
obj.id = 999
|
||||
|
||||
self.flush.side_effect = _flush_assigning_ids
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def add_all(self, objs: list[Any]) -> None:
|
||||
self.added.extend(objs)
|
||||
|
||||
|
||||
def _action(*, tool_name: str, action_id: int = 7):
|
||||
return MagicMock(
|
||||
id=action_id,
|
||||
tool_name=tool_name,
|
||||
thread_id=1,
|
||||
search_space_id=2,
|
||||
user_id="user-1",
|
||||
reverse_descriptor=None,
|
||||
)
|
||||
|
||||
|
||||
def _doc_revision(
|
||||
*,
|
||||
document_id: int | None = None,
|
||||
content_before: str | None = "old content",
|
||||
title_before: str | None = "notes.md",
|
||||
folder_id_before: int | None = 5,
|
||||
chunks_before: list[dict[str, str]] | None = None,
|
||||
metadata_before: dict[str, str] | None = None,
|
||||
):
|
||||
revision = MagicMock()
|
||||
revision.id = 100
|
||||
revision.document_id = document_id
|
||||
revision.search_space_id = 2
|
||||
revision.content_before = content_before
|
||||
revision.title_before = title_before
|
||||
revision.folder_id_before = folder_id_before
|
||||
revision.chunks_before = chunks_before or []
|
||||
revision.metadata_before = metadata_before
|
||||
return revision
|
||||
|
||||
|
||||
def _folder_revision(
|
||||
*,
|
||||
folder_id: int | None = None,
|
||||
name_before: str | None = "team",
|
||||
parent_id_before: int | None = None,
|
||||
position_before: str | None = "a0",
|
||||
):
|
||||
revision = MagicMock()
|
||||
revision.id = 200
|
||||
revision.folder_id = folder_id
|
||||
revision.search_space_id = 2
|
||||
revision.name_before = name_before
|
||||
revision.parent_id_before = parent_id_before
|
||||
revision.position_before = position_before
|
||||
return revision
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exact-name dispatch regression guards
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExactDispatch:
    """Guards against ``rmdir`` being dispatched to the document branch."""

    @pytest.mark.asyncio
    async def test_rmdir_does_not_misroute_to_document(self) -> None:
        # A ``startswith("rm")`` dispatch would send ``rmdir`` down the
        # document branch; exact-name lookup puts it in ``_FOLDER_TOOLS``.
        rmdir_action = _action(tool_name="rmdir")
        session = _FakeSession()
        # The action has no folder revisions recorded.
        session.execute.return_value = _FakeResult(rows=[])

        outcome = await revert_service.revert_action(
            session,  # type: ignore[arg-type]
            action=rmdir_action,
            requester_user_id="user-1",
        )

        assert outcome.status == "not_reversible"
        assert "folder_revisions" in outcome.message

    def test_dispatch_sets_split_doc_and_folder(self) -> None:
        # Static checks on the dispatch tables so that a later refactor
        # cannot quietly bring the prefix bug back.
        doc_tools = revert_service._DOC_TOOLS
        folder_tools = revert_service._FOLDER_TOOLS

        assert "rm" in doc_tools
        assert "rmdir" in folder_tools
        assert "rmdir" not in doc_tools
        assert "rm" not in folder_tools
        # ``move_file`` is a document rename, so it is a document tool only.
        assert "move_file" in doc_tools
        assert "move_file" not in folder_tools
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# rm revert (re-INSERT)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRmRevert:
    """Reverting ``rm``: the document is re-INSERTed from its snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_document_with_chunks(self) -> None:
        deleted_doc = _doc_revision(
            document_id=None,  # row was hard-deleted
            content_before="hello world",
            title_before="x.md",
            folder_id_before=None,
            chunks_before=[{"content": "alpha"}, {"content": "beta"}],
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        session = _FakeSession()
        # The collision lookup finds no existing row.
        session.execute.return_value = _FakeResult(scalar=None)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=deleted_doc,
        )

        assert outcome.status == "ok"
        # Exactly one new Document and two Chunks must have been added.
        from app.db import Chunk, Document

        new_docs = list(
            filter(lambda obj: isinstance(obj, Document), session.added)
        )
        new_chunks = list(
            filter(lambda obj: isinstance(obj, Chunk), session.added)
        )
        assert len(new_docs) == 1
        assert new_docs[0].title == "x.md"
        assert len(new_chunks) == 2
        # The snapshot now points at the new doc id, so a follow-up revert works.
        assert deleted_doc.document_id == new_docs[0].id

    @pytest.mark.asyncio
    async def test_falls_back_to_folder_id_and_title_for_virtual_path(
        self,
    ) -> None:
        # The snapshot carries NO metadata_before, forcing the fallback path.
        deleted_doc = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="cap.md",
            folder_id_before=42,
            chunks_before=[],
            metadata_before=None,
        )
        # Folder handed back by session.get(Folder, 42) during path derivation.
        parent_folder = MagicMock()
        parent_folder.name = "team"
        parent_folder.parent_id = None

        session = _FakeSession()
        session.get = AsyncMock(return_value=parent_folder)
        session.execute.return_value = _FakeResult(scalar=None)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=deleted_doc,
        )

        assert outcome.status == "ok"

    @pytest.mark.asyncio
    async def test_falls_back_to_root_path_when_no_folder(
        self,
    ) -> None:
        """With neither ``metadata_before`` nor ``folder_id_before`` the
        path still resolves: the title fallback gives
        ``/documents/<title>``, so the revert proceeds at the root of the
        documents tree."""
        deleted_doc = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="x.md",
            folder_id_before=None,
            metadata_before=None,
        )
        session = _FakeSession()
        # Nothing already lives at /documents/x.md.
        session.execute.return_value = _FakeResult(scalar=None)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=deleted_doc,
        )

        assert outcome.status == "ok"

    @pytest.mark.asyncio
    async def test_collision_with_live_doc_returns_tool_unavailable(self) -> None:
        deleted_doc = _doc_revision(
            document_id=None,
            content_before="hi",
            title_before="x.md",
            folder_id_before=None,
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        session = _FakeSession()
        # The unique_identifier_hash collision SELECT finds an existing row.
        session.execute.return_value = _FakeResult(scalar=42)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=deleted_doc,
        )

        assert outcome.status == "tool_unavailable"
        assert "collide" in outcome.message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# write_file create revert (DELETE)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWriteFileCreateRevert:
    """Reverting a ``write_file`` that created a document deletes it."""

    @pytest.mark.asyncio
    async def test_deletes_created_doc(self) -> None:
        created_doc = _doc_revision(
            document_id=99,
            content_before=None,  # marker for "created in this action"
            title_before=None,
        )
        session = _FakeSession()

        outcome = await revert_service._delete_created_document(
            session,  # type: ignore[arg-type]
            revision=created_doc,
        )

        assert outcome.status == "ok"
        # One statement executed in total: the DELETE.
        assert session.execute.await_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# rmdir revert (re-INSERT folder)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRmdirRevert:
    """Reverting ``rmdir``: the folder is re-INSERTed from its snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_folder_from_snapshot(self) -> None:
        deleted_folder = _folder_revision(
            folder_id=None,
            name_before="team",
            parent_id_before=None,
            position_before="a0",
        )
        session = _FakeSession()

        outcome = await revert_service._reinsert_folder_from_revision(
            session,  # type: ignore[arg-type]
            revision=deleted_folder,
        )
        from app.db import Folder

        assert outcome.status == "ok"
        new_folders = list(
            filter(lambda obj: isinstance(obj, Folder), session.added)
        )
        assert len(new_folders) == 1
        restored = new_folders[0]
        assert restored.name == "team"
        # The snapshot must now reference the re-created folder.
        assert deleted_folder.folder_id == restored.id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# mkdir revert (DELETE folder)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMkdirRevert:
    """Reverting ``mkdir``: the folder is deleted only while still empty."""

    @pytest.mark.asyncio
    async def test_deletes_empty_folder(self) -> None:
        created_folder = _folder_revision(folder_id=42)
        session = _FakeSession()
        # Scripted results, one per expected statement: both emptiness
        # checks come back empty, then the DELETE (no return value).
        emptiness_docs = _FakeResult(scalar=None)  # docs
        emptiness_children = _FakeResult(scalar=None)  # children
        delete_stmt = _FakeResult(scalar=None)  # delete
        session.execute.side_effect = [
            emptiness_docs,
            emptiness_children,
            delete_stmt,
        ]

        outcome = await revert_service._delete_created_folder(
            session,  # type: ignore[arg-type]
            revision=created_folder,
        )

        assert outcome.status == "ok"
        # Three executes in total: docs check, children check, delete.
        assert session.execute.await_count == 3

    @pytest.mark.asyncio
    async def test_reports_tool_unavailable_when_folder_has_children(self) -> None:
        created_folder = _folder_revision(folder_id=42)
        session = _FakeSession()
        # Every execute reports "row found", so the first check (docs)
        # already sees the folder as non-empty.
        session.execute.return_value = _FakeResult(scalar=1)

        outcome = await revert_service._delete_created_folder(
            session,  # type: ignore[arg-type]
            revision=created_folder,
        )

        assert outcome.status == "tool_unavailable"
        assert "no longer empty" in outcome.message
|
||||
0
surfsense_backend/tests/unit/tasks/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/chat/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/chat/__init__.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""Unit tests for ``stream_new_chat._extract_chunk_parts``.
|
||||
|
||||
Earlier versions only handled ``isinstance(chunk.content, str)`` and
|
||||
silently dropped every other shape (Anthropic typed-block lists,
|
||||
Bedrock reasoning blocks, ``additional_kwargs.reasoning_content`` from
|
||||
a few providers). These regression tests pin those four shapes plus the
|
||||
defensive cases (``None`` chunk, mixed types, missing fields).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.chat.stream_new_chat import _extract_chunk_parts
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeChunk:
|
||||
"""Minimal stand-in for ``AIMessageChunk`` used in unit tests."""
|
||||
|
||||
content: Any = ""
|
||||
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class TestStringContent:
    """Plain ``str`` content is surfaced as text and nothing else."""

    def test_plain_string_content_extracts_as_text(self) -> None:
        parts = _extract_chunk_parts(_FakeChunk(content="hello world"))
        assert parts["text"] == "hello world"
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []

    def test_empty_string_content_yields_empty_text(self) -> None:
        parts = _extract_chunk_parts(_FakeChunk(content=""))
        assert parts["text"] == ""
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []
|
||||
|
||||
|
||||
class TestListContent:
    """Content supplied as a list of typed blocks."""

    def test_list_of_text_blocks_concatenates(self) -> None:
        blocks = [
            {"type": "text", "text": "Hello "},
            {"type": "text", "text": "world"},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["text"] == "Hello world"
        assert parts["reasoning"] == ""

    def test_mixed_text_and_reasoning_blocks(self) -> None:
        # Reasoning blocks carry their payload under either ``reasoning``
        # or ``text``; both must land in the reasoning output.
        blocks = [
            {"type": "reasoning", "reasoning": "Let me think... "},
            {"type": "reasoning", "text": "still thinking."},
            {"type": "text", "text": "The answer is 42."},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["text"] == "The answer is 42."
        assert parts["reasoning"] == "Let me think... still thinking."

    def test_tool_call_chunks_in_content_list_extracted(self) -> None:
        tool_block = {
            "type": "tool_call_chunk",
            "id": "call_123",
            "name": "make_widget",
            "args": '{"color":"red"}',
        }
        blocks = [{"type": "text", "text": "Calling tool..."}, tool_block]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["text"] == "Calling tool..."
        assert parts["reasoning"] == ""
        assert len(parts["tool_call_chunks"]) == 1
        extracted = parts["tool_call_chunks"][0]
        assert extracted["id"] == "call_123"
        assert extracted["name"] == "make_widget"

    def test_tool_use_blocks_also_extracted(self) -> None:
        """Blocks typed ``'tool_use'`` (the Anthropic spelling) count too."""
        blocks = [
            {
                "type": "tool_use",
                "id": "call_xyz",
                "name": "search",
            },
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["tool_call_chunks"] == [
            {"type": "tool_use", "id": "call_xyz", "name": "search"}
        ]

    def test_unknown_block_types_are_ignored(self) -> None:
        blocks = [
            {"type": "image_url", "url": "https://example.com/x.png"},
            {"type": "text", "text": "ok"},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["text"] == "ok"

    def test_blocks_without_text_field_are_ignored(self) -> None:
        blocks = [
            {"type": "text"},  # no text/content key
            {"type": "text", "text": "kept"},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))
        assert parts["text"] == "kept"
|
||||
|
||||
|
||||
class TestAdditionalKwargsReasoning:
    """Reasoning delivered through ``additional_kwargs``."""

    def test_reasoning_content_in_additional_kwargs(self) -> None:
        """``additional_kwargs.reasoning_content`` is surfaced as reasoning."""
        parts = _extract_chunk_parts(
            _FakeChunk(
                content="visible answer",
                additional_kwargs={"reasoning_content": "internal monologue"},
            )
        )
        assert parts["text"] == "visible answer"
        assert parts["reasoning"] == "internal monologue"

    def test_reasoning_appended_to_typed_block_reasoning(self) -> None:
        # Block-sourced reasoning comes first; the kwargs text is appended.
        parts = _extract_chunk_parts(
            _FakeChunk(
                content=[{"type": "reasoning", "text": "from blocks. "}],
                additional_kwargs={"reasoning_content": "from kwargs."},
            )
        )
        assert parts["reasoning"] == "from blocks. from kwargs."
|
||||
|
||||
|
||||
class TestToolCallChunksAttribute:
    """Tool-call chunks supplied via the ``tool_call_chunks`` attribute."""

    def test_tool_call_chunks_attribute_extracted_alongside_string_content(
        self,
    ) -> None:
        attr_chunk = {"name": "save_document", "args": '{"title":"x"}', "id": "tc-9"}
        parts = _extract_chunk_parts(
            _FakeChunk(content="streaming text", tool_call_chunks=[attr_chunk])
        )
        assert parts["text"] == "streaming text"
        assert len(parts["tool_call_chunks"]) == 1
        assert parts["tool_call_chunks"][0]["id"] == "tc-9"

    def test_attribute_and_typed_block_chunks_both_collected(self) -> None:
        block_chunk = {
            "type": "tool_call_chunk",
            "id": "from-block",
            "name": "x",
        }
        parts = _extract_chunk_parts(
            _FakeChunk(
                content=[block_chunk],
                tool_call_chunks=[{"id": "from-attr", "name": "y"}],
            )
        )
        # Block-sourced chunks are expected ahead of attribute-sourced ones.
        collected_ids = []
        for tcc in parts["tool_call_chunks"]:
            collected_ids.append(tcc.get("id"))
        assert collected_ids == ["from-block", "from-attr"]
|
||||
|
||||
|
||||
class TestDefensive:
    """Malformed input must yield empty parts."""

    @pytest.mark.parametrize(
        "chunk_value",
        [None, _FakeChunk(content=None), _FakeChunk(content=42)],
    )
    def test_invalid_chunk_returns_empty_parts(self, chunk_value: Any) -> None:
        # Cases: no chunk at all, ``None`` content, non-str/non-list content.
        parts = _extract_chunk_parts(chunk_value)
        assert parts["text"] == ""
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []
|
||||
|
||||
|
||||
class TestIdlessContinuationChunks:
    """Continuation chunks without id or name must pass through intact.

    Under LangChain ``ToolCallChunk`` semantics only the FIRST chunk of a
    tool call carries ``id`` and ``name``; every later chunk of that call
    has ``id=None, name=None`` and just ``args`` + ``index``. Live
    argument streaming needs ``_extract_chunk_parts`` to hand those
    chunks on UNTOUCHED so the chunk-emission loop upstream can still
    route them by ``index``.
    """

    def test_idless_continuation_chunk_preserved_verbatim(self) -> None:
        continuation = {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
        parts = _extract_chunk_parts(_FakeChunk(tool_call_chunks=[continuation]))

        assert len(parts["tool_call_chunks"]) == 1
        passed_through = parts["tool_call_chunks"][0]
        assert passed_through.get("id") is None
        assert passed_through.get("name") is None
        assert passed_through.get("args") == '_path":"/x"}'
        assert passed_through.get("index") == 0

    def test_first_then_idless_sequence_preserves_index(self) -> None:
        """The opening chunk and its continuation share one ``index``
        key, which the index-routing loop in ``stream_new_chat`` relies
        on."""
        opening = _FakeChunk(
            tool_call_chunks=[
                {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
            ]
        )
        continuation = _FakeChunk(
            tool_call_chunks=[
                {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
            ]
        )

        opening_parts = _extract_chunk_parts(opening)
        continuation_parts = _extract_chunk_parts(continuation)

        assert opening_parts["tool_call_chunks"][0]["index"] == 0
        assert continuation_parts["tool_call_chunks"][0]["index"] == 0
        assert continuation_parts["tool_call_chunks"][0].get("id") is None
|
||||
|
|
@ -0,0 +1,527 @@
|
|||
"""Unit tests for live tool-call argument streaming.
|
||||
|
||||
Pins the wire format that ``_stream_agent_events`` emits when
|
||||
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start`` →
|
||||
``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available``
|
||||
all keyed by the same LangChain ``tool_call.id``.
|
||||
|
||||
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
|
||||
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
|
||||
``_stream_agent_events`` so we exercise them via the public wire output.
|
||||
|
||||
These tests also lock in the legacy / parity_v2-OFF behaviour so the
|
||||
synthetic ``call_<run_id>`` shape stays stable for older clients.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import app.tasks.chat.stream_new_chat as stream_module
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
StreamResult,
|
||||
_legacy_match_lc_id,
|
||||
_stream_agent_events,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeChunk:
|
||||
"""Minimal stand-in for ``AIMessageChunk``."""
|
||||
|
||||
content: Any = ""
|
||||
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeToolMessage:
|
||||
"""Stand-in for ``ToolMessage`` returned by ``on_tool_end``."""
|
||||
|
||||
content: Any
|
||||
tool_call_id: str | None = None
|
||||
|
||||
|
||||
class _FakeAgentState:
|
||||
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Empty values keeps the cloud-fallback safety-net branch a no-op,
|
||||
# and an empty ``tasks`` list keeps the post-stream interrupt
|
||||
# check a no-op too.
|
||||
self.values: dict[str, Any] = {}
|
||||
self.tasks: list[Any] = []
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
"""Replays a list of ``astream_events`` events."""
|
||||
|
||||
def __init__(self, events: list[dict[str, Any]]) -> None:
|
||||
self._events = events
|
||||
|
||||
async def astream_events( # type: ignore[no-untyped-def]
|
||||
self, _input_data: Any, *, config: dict[str, Any], version: str
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
del config, version # unused, contract-compatible
|
||||
for ev in self._events:
|
||||
yield ev
|
||||
|
||||
async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState:
|
||||
# Called once after astream_events drains so the cloud-fallback
|
||||
# safety net can inspect staged filesystem work. The fake stays
|
||||
# empty so the safety net is a no-op.
|
||||
return _FakeAgentState()
|
||||
|
||||
|
||||
def _model_stream(
    *,
    text: str = "",
    reasoning: str = "",
    tool_call_chunks: list[dict[str, Any]] | None = None,
    tags: list[str] | None = None,
) -> dict[str, Any]:
    """Build an ``on_chat_model_stream`` event wrapping a ``_FakeChunk``.

    ``text`` becomes the chunk's content and ``tool_call_chunks`` is
    copied onto it. A non-empty ``reasoning`` rides along in
    ``additional_kwargs["reasoning_content"]``; otherwise the chunk
    keeps its default (empty) ``additional_kwargs``.
    """
    chunk_fields: dict[str, Any] = {"content": text}
    if reasoning:
        # Reasoning piggybacks on the additional_kwargs path. A test that
        # needs typed reasoning blocks would have to override content
        # itself; most tests only check tool_call_chunks routing.
        chunk_fields["additional_kwargs"] = {"reasoning_content": reasoning}
    chunk_fields["tool_call_chunks"] = list(tool_call_chunks or [])

    return {
        "event": "on_chat_model_stream",
        "tags": tags or [],
        "data": {"chunk": _FakeChunk(**chunk_fields)},
    }
|
||||
|
||||
|
||||
def _tool_start(
|
||||
*,
|
||||
name: str,
|
||||
run_id: str,
|
||||
input_payload: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"event": "on_tool_start",
|
||||
"name": name,
|
||||
"run_id": run_id,
|
||||
"data": {"input": input_payload or {}},
|
||||
}
|
||||
|
||||
|
||||
def _tool_end(
    *,
    name: str,
    run_id: str,
    tool_call_id: str | None = None,
    output: Any = "ok",
) -> dict[str, Any]:
    """Build an ``on_tool_end`` event whose output is a ``_FakeToolMessage``.

    String outputs are used verbatim as the message content; any other
    value is JSON-encoded first.
    """
    if isinstance(output, str):
        message_content = output
    else:
        message_content = json.dumps(output)

    tool_message = _FakeToolMessage(
        content=message_content,
        tool_call_id=tool_call_id,
    )
    return {
        "event": "on_tool_end",
        "name": name,
        "run_id": run_id,
        "data": {"output": tool_message},
    }
|
||||
|
||||
|
||||
@pytest.fixture
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch ``stream_module.get_flags`` to report stream parity v2 ENABLED."""

    def _flags_enabled() -> AgentFeatureFlags:
        return AgentFeatureFlags(enable_stream_parity_v2=True)

    monkeypatch.setattr(stream_module, "get_flags", _flags_enabled)
|
||||
|
||||
|
||||
@pytest.fixture
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch ``stream_module.get_flags`` to report stream parity v2 DISABLED."""

    def _flags_disabled() -> AgentFeatureFlags:
        return AgentFeatureFlags(enable_stream_parity_v2=False)

    monkeypatch.setattr(stream_module, "get_flags", _flags_disabled)
|
||||
|
||||
|
||||
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Feed ``events`` through ``_stream_agent_events`` and decode the output.

    The scripted events are replayed by a ``_FakeAgent``. Every yielded
    SSE frame is collected first; then each ``data: `` frame whose body
    is non-empty, is not ``[DONE]`` and parses as JSON is returned in
    emission order. All other frames are dropped silently.
    """
    stream = _stream_agent_events(
        _FakeAgent(events),
        {"configurable": {"thread_id": "test-thread"}},
        {},
        VercelStreamingService(),
        StreamResult(),
        step_prefix="thinking",
    )
    # Drain the stream completely before any parsing happens.
    frames: list[str] = [frame async for frame in stream]

    data_prefix = "data: "
    payloads: list[dict[str, Any]] = []
    for frame in frames:
        if not frame.startswith(data_prefix):
            continue
        body = frame.removeprefix(data_prefix).rstrip("\n")
        if body in ("", "[DONE]"):
            continue
        try:
            decoded = json.loads(body)
        except json.JSONDecodeError:
            # Bodies that are not JSON are skipped.
            continue
        payloads.append(decoded)
    return payloads
|
||||
|
||||
|
||||
def _types(payloads: list[dict[str, Any]]) -> list[str]:
|
||||
return [p.get("type", "?") for p in payloads]
|
||||
|
||||
|
||||
def _of_type(payloads: list[dict[str, Any]], type_name: str) -> list[dict[str, Any]]:
|
||||
return [p for p in payloads if p.get("type") == type_name]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: ``_legacy_match_lc_id`` is a pure refactor; assert behaviour.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLegacyMatch:
    """Behaviour of the ``_legacy_match_lc_id`` helper."""

    def test_pops_first_id_bearing_chunk_with_matching_name(self) -> None:
        pending: list[dict[str, Any]] = [
            {"id": "x1", "name": "ls"},
            {"id": "y1", "name": "write_file"},
        ]
        run_map: dict[str, str] = {}

        matched = _legacy_match_lc_id(pending, "write_file", "run-1", run_map)

        assert matched == "y1"
        # The matched chunk is consumed; the other one stays put.
        assert pending == [{"id": "x1", "name": "ls"}]
        assert run_map == {"run-1": "y1"}

    def test_falls_back_to_any_id_bearing_when_name_mismatches(self) -> None:
        pending: list[dict[str, Any]] = [{"id": "anon", "name": None}]
        run_map: dict[str, str] = {}

        matched = _legacy_match_lc_id(pending, "ls", "run-2", run_map)

        assert matched == "anon"
        assert pending == []

    def test_returns_none_when_no_id_bearing_chunk(self) -> None:
        pending: list[dict[str, Any]] = [{"id": None, "name": None}]
        run_map: dict[str, str] = {}

        assert _legacy_match_lc_id(pending, "ls", "run-3", run_map) is None
        # Neither the chunk list nor the run map may have been touched.
        assert pending == [{"id": None, "name": None}]
        assert run_map == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parity_v2 wire format tests.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
    """Idless continuation chunks attach to the call opened at their ``index``.

    The first chunk carries id+name; later idless chunks at the same
    ``index`` must reuse the SAME ``tool-input-start`` ui id, with one
    ``tool-input-delta`` emitted per chunk.
    """
    opening_chunk = {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
    continuation_chunk = {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
    scripted_events = [
        _model_stream(tool_call_chunks=[opening_chunk]),
        _model_stream(tool_call_chunks=[continuation_chunk]),
        _tool_start(
            name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
        ),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]

    payloads = await _drain(scripted_events)

    starts = _of_type(payloads, "tool-input-start")
    deltas = _of_type(payloads, "tool-input-delta")
    available = _of_type(payloads, "tool-input-available")
    output = _of_type(payloads, "tool-output-available")

    # A single start frame, keyed by the LangChain id.
    assert len(starts) == 1
    start = starts[0]
    assert start["toolCallId"] == "lc-1"
    assert start["toolName"] == "write_file"
    assert start["langchainToolCallId"] == "lc-1"

    # One delta per chunk, in arrival order, all under the same id.
    assert [d["inputTextDelta"] for d in deltas] == ['{"file', '_path":"/x"}']
    assert all(d["toolCallId"] == "lc-1" for d in deltas)

    assert len(available) == 1
    assert available[0]["toolCallId"] == "lc-1"

    assert len(output) == 1
    assert output[0]["toolCallId"] == "lc-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_two_interleaved_tool_calls_route_by_index(
    parity_v2_on: None,
) -> None:
    """Two same-name calls at different indices must each keep their own
    deltas, i.e. every delta lands on the right card."""
    opening_chunks = [
        {"id": "lc-A", "name": "write_file", "args": '{"a":1', "index": 0},
        {"id": "lc-B", "name": "write_file", "args": '{"b":2', "index": 1},
    ]
    closing_chunks = [
        {"id": None, "name": None, "args": "}", "index": 0},
        {"id": None, "name": None, "args": "}", "index": 1},
    ]
    scripted_events = [
        _model_stream(tool_call_chunks=opening_chunks),
        _model_stream(tool_call_chunks=closing_chunks),
        _tool_start(name="write_file", run_id="run-A", input_payload={"a": 1}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-A"),
        _tool_start(name="write_file", run_id="run-B", input_payload={"b": 2}),
        _tool_end(name="write_file", run_id="run-B", tool_call_id="lc-B"),
    ]

    payloads = await _drain(scripted_events)

    starts = _of_type(payloads, "tool-input-start")
    deltas = _of_type(payloads, "tool-input-delta")
    output = _of_type(payloads, "tool-output-available")

    assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"}

    # Bucket the delta texts by tool-call id; a delta carrying any other
    # id raises KeyError and fails the test.
    deltas_by_call: dict[str, list[str]] = {"lc-A": [], "lc-B": []}
    for delta in deltas:
        deltas_by_call[delta["toolCallId"]].append(delta["inputTextDelta"])
    assert deltas_by_call["lc-A"] == ['{"a":1', "}"]
    assert deltas_by_call["lc-B"] == ['{"b":2', "}"]

    assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
    """The id picked at ``tool-input-start`` must not change afterwards.

    ``tool-input-available`` and ``tool-output-available`` have to reuse
    exactly the id that the start event announced.
    """
    lifecycle_types = {
        "tool-input-start",
        "tool-input-available",
        "tool-output-available",
    }
    stream = [
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-9", "name": "ls", "args": '{"path":"/"}', "index": 0}
            ]
        ),
        _tool_start(name="ls", run_id="run-X", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-X", tool_call_id="lc-9"),
    ]

    emitted = await _drain(stream)

    lifecycle_events = [
        payload for payload in emitted if payload.get("type") in lifecycle_types
    ]
    assert {payload["toolCallId"] for payload in lifecycle_events} == {"lc-9"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
    """``on_tool_start`` stays silent when the chunk loop already opened the card.

    The streamed chunk has already produced a ``tool-input-start`` for
    this run, so a second one from ``on_tool_start`` would be a duplicate.
    """
    streamed_chunk = {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
    stream = [
        _model_stream(tool_call_chunks=[streamed_chunk]),
        _tool_start(name="write_file", run_id="run-A", input_payload={}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]

    input_starts = _of_type(await _drain(stream), "tool-input-start")

    assert len(input_starts) == 1
    assert input_starts[0]["toolCallId"] == "lc-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_active_text_closes_before_early_tool_input_start(
    parity_v2_on: None,
) -> None:
    """An open text part is closed before the first tool card opens.

    A text chunk is followed by a tool-call chunk in a later model
    event; ``text-end`` must precede the first ``tool-input-start`` on
    the wire so the frontend sees a clean part boundary.
    """
    stream = [
        _model_stream(text="Working on it"),
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
            ]
        ),
        _tool_start(name="write_file", run_id="run-A", input_payload={}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]

    emitted = await _drain(stream)
    wire_types = _types(emitted)

    # ``index`` raises ValueError if either event type is missing entirely.
    assert wire_types.index("text-end") < wire_types.index("tool-input-start")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mixed_text_and_tool_chunk_preserve_order(
    parity_v2_on: None,
) -> None:
    """A chunk with both text and tool-call data emits the text part first.

    Expected wire order for a single AIMessageChunk carrying ``text``
    and ``tool_call_chunks``: text-start, text-delta, text-end,
    tool-input-start, tool-input-delta.
    """
    combined_chunk = _model_stream(
        text="I'll update it",
        tool_call_chunks=[
            {
                "id": "lc-1",
                "name": "write_file",
                "args": '{"file_path":"/x"}',
                "index": 0,
            }
        ],
    )
    stream = [
        combined_chunk,
        _tool_start(
            name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
        ),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]

    wire_types = _types(await _drain(stream))

    # Each pair is (earlier, later) by first occurrence on the wire.
    expected_order = [
        ("text-start", "text-delta"),
        ("text-delta", "text-end"),
        ("text-end", "tool-input-start"),
        ("tool-input-start", "tool-input-delta"),
    ]
    for earlier, later in expected_order:
        assert wire_types.index(earlier) < wire_types.index(later)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parity_v2_off_preserves_legacy_shape(
    parity_v2_off: None,
) -> None:
    """With the flag OFF the legacy wire shape is kept.

    No ``tool-input-delta`` events are produced and ``toolCallId`` is
    the synthetic ``call_<run_id>`` value rather than the LangChain id.
    """
    stream = [
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
            ]
        ),
        _tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
    ]

    emitted = await _drain(stream)

    assert _of_type(emitted, "tool-input-delta") == []

    input_starts = _of_type(emitted, "tool-input-start")
    assert len(input_starts) == 1
    only_start = input_starts[0]
    assert only_start["toolCallId"].startswith("call_run-A")
    # Legacy mode does not put ``langchainToolCallId`` on the start
    # event: it fires before the ToolMessage exists, so the
    # authoritative LangChain id cannot be extracted yet.
    assert "langchainToolCallId" not in only_start

    tool_outputs = _of_type(emitted, "tool-output-available")
    first_output = tool_outputs[0]
    assert first_output["toolCallId"].startswith("call_run-A")
    # The output event MUST still carry ``langchainToolCallId`` in
    # legacy mode. The chat tool card backfills the LangChain id from it
    # and joins against the ``data-action-log`` SSE event (keyed by
    # ``lc_tool_call_id``) so the inline Revert button can light up. The
    # value comes from the returned ``ToolMessage.tool_call_id``, which
    # is populated regardless of feature-flag state.
    assert first_output["langchainToolCallId"] == "lc-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_skip_append_prevents_stale_id_reuse(
    parity_v2_on: None,
) -> None:
    """A second same-name tool must not inherit the first tool's chunk id.

    Chunks registered by index are expected to stay out of
    ``pending_tool_call_chunks``, so the second ``write_file`` run
    cannot pick up a stale entry left behind by the first one.
    """
    indexed_chunks = [
        {"id": "lc-A", "name": "write_file", "args": "{}", "index": 0},
        {"id": "lc-B", "name": "write_file", "args": "{}", "index": 1},
    ]
    stream = [
        _model_stream(tool_call_chunks=indexed_chunks),
        _tool_start(name="write_file", run_id="run-1", input_payload={}),
        _tool_end(name="write_file", run_id="run-1", tool_call_id="lc-A"),
        _tool_start(name="write_file", run_id="run-2", input_payload={}),
        _tool_end(name="write_file", run_id="run-2", tool_call_id="lc-B"),
    ]

    emitted = await _drain(stream)

    # Each LangChain id gets its own card ...
    input_starts = _of_type(emitted, "tool-input-start")
    assert {start["toolCallId"] for start in input_starts} == {"lc-A", "lc-B"}
    # ... and each output lands on the card that belongs to it.
    tool_outputs = _of_type(emitted, "tool-output-available")
    assert {result["toolCallId"] for result in tool_outputs} == {"lc-A", "lc-B"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_registration_waits_for_both_id_and_name(
    parity_v2_on: None,
) -> None:
    """A chunk with an id but no name yet must not open a tool card."""
    nameless_chunk = {"id": "lc-1", "name": None, "args": "", "index": 0}
    stream = [_model_stream(tool_call_chunks=[nameless_chunk])]

    emitted = await _drain(stream)

    assert _of_type(emitted, "tool-input-start") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_unmatched_fallback_still_attaches_lc_id(
    parity_v2_on: None,
) -> None:
    """Index-less chunks fall back to the legacy match but keep the lc id.

    With parity_v2 ON and a provider that omits ``index``, the legacy
    path must still emit ``tool-input-start`` carrying the matching
    ``langchainToolCallId``.
    """
    # Without an index the chunk is not registered into index_to_meta;
    # it goes to ``pending_tool_call_chunks`` instead, where the legacy
    # match path can pop it at on_tool_start.
    indexless_chunk = {"id": "lc-orphan", "name": "ls", "args": ""}
    stream = [
        _model_stream(tool_call_chunks=[indexless_chunk]),
        _tool_start(name="ls", run_id="run-1", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-1", tool_call_id="lc-orphan"),
    ]

    input_starts = _of_type(await _drain(stream), "tool-input-start")

    assert len(input_starts) == 1
    only_start = input_starts[0]
    assert only_start["toolCallId"].startswith("call_run-1")
    assert only_start["langchainToolCallId"] == "lc-orphan"
|
||||
66
surfsense_web/app/(home)/blog/[slug]/loading.tsx
Normal file
66
surfsense_web/app/(home)/blog/[slug]/loading.tsx
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
||||
/**
 * Skeleton placeholder shaped like a blog post: breadcrumb, tags, title,
 * description, author row, cover image and body paragraphs.
 */
export default function BlogPostLoading() {
	// Five three-line paragraphs rendered above the sub-heading.
	const leadParagraphs = Array.from({ length: 5 }, (_, idx) => (
		<div key={idx} className="space-y-2 mb-6">
			<Skeleton className="h-4 w-full" />
			<Skeleton className="h-4 w-full" />
			<Skeleton className="h-4 w-4/5" />
		</div>
	));

	// Three slightly shorter paragraphs rendered below the sub-heading.
	const trailingParagraphs = Array.from({ length: 3 }, (_, idx) => (
		<div key={idx} className="space-y-2 mb-6">
			<Skeleton className="h-4 w-full" />
			<Skeleton className="h-4 w-11/12" />
			<Skeleton className="h-4 w-3/4" />
		</div>
	));

	return (
		<div className="min-h-screen relative pt-20">
			<div className="max-w-3xl mx-auto px-6 lg:px-10 pt-10 pb-20">
				{/* Breadcrumb trail */}
				<div className="flex items-center gap-2 mb-8">
					<Skeleton className="h-4 w-10" />
					<Skeleton className="h-4 w-3" />
					<Skeleton className="h-4 w-10" />
					<Skeleton className="h-4 w-3" />
					<Skeleton className="h-4 w-40" />
				</div>

				{/* Tag pills */}
				<div className="flex flex-wrap gap-2 mb-4">
					<Skeleton className="h-6 w-16 rounded-full" />
					<Skeleton className="h-6 w-20 rounded-full" />
				</div>

				{/* Two-line title */}
				<div className="space-y-3 mb-6">
					<Skeleton className="h-10 w-full" />
					<Skeleton className="h-10 w-4/5" />
				</div>

				{/* Two-line description */}
				<Skeleton className="h-5 w-full mb-2" />
				<Skeleton className="h-5 w-3/4 mb-8" />

				{/* Avatar with author name and date */}
				<div className="flex items-center gap-3 mb-10">
					<Skeleton className="h-10 w-10 rounded-full" />
					<div className="space-y-1.5">
						<Skeleton className="h-4 w-32" />
						<Skeleton className="h-3 w-24" />
					</div>
				</div>

				{/* Cover image */}
				<Skeleton className="w-full aspect-video rounded-xl mb-10" />

				{/* Article body */}
				{leadParagraphs}

				{/* Sub-heading */}
				<Skeleton className="h-7 w-56 mt-8 mb-4" />

				{trailingParagraphs}
			</div>
		</div>
	);
}
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
import { format } from "date-fns";
|
||||
import FuzzySearch from "fuzzy-search";
|
||||
import Link from "next/link";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { useMemo, useState } from "react";
|
||||
import { Container } from "@/components/container";
|
||||
import type { BlogEntry } from "./page";
|
||||
|
||||
|
|
@ -127,17 +127,13 @@ function MagazineSearchGrid({
|
|||
[allBlogs]
|
||||
);
|
||||
|
||||
const [results, setResults] = useState(allBlogs);
|
||||
useEffect(() => {
|
||||
setResults(searcher.search(search));
|
||||
}, [search, searcher]);
|
||||
|
||||
const gridItems = useMemo(() => {
|
||||
const results = search.trim() ? searcher.search(search) : allBlogs;
|
||||
if (search.trim()) {
|
||||
return results;
|
||||
}
|
||||
return results.filter((b) => b.slug !== featuredSlug);
|
||||
}, [results, search, featuredSlug]);
|
||||
}, [search, searcher, allBlogs, featuredSlug]);
|
||||
|
||||
return (
|
||||
<section aria-labelledby="archive-heading">
|
||||
|
|
|
|||
50
surfsense_web/app/(home)/blog/loading.tsx
Normal file
50
surfsense_web/app/(home)/blog/loading.tsx
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
||||
/**
 * Skeleton placeholder shaped like the blog index: page header, featured
 * post card, search bar and a six-card article grid.
 */
export default function BlogIndexLoading() {
	// Six article-card placeholders for the grid.
	const articleCards = Array.from({ length: 6 }, (_, idx) => (
		<div key={idx} className="space-y-3">
			<Skeleton className="aspect-video w-full rounded-2xl" />
			<Skeleton className="h-5 w-3/4" />
			<Skeleton className="h-4 w-full" />
			<Skeleton className="h-4 w-5/6" />
			<div className="flex items-center gap-2 pt-1">
				<Skeleton className="h-6 w-6 rounded-full" />
				<Skeleton className="h-4 w-24" />
			</div>
		</div>
	));

	return (
		<div className="relative overflow-hidden bg-neutral-50 px-4 pt-20 md:px-8 dark:bg-neutral-950">
			<div className="mx-auto max-w-6xl pt-12 pb-24 md:pt-20">
				{/* Page heading */}
				<div className="mb-10 md:mb-14">
					<Skeleton className="h-10 w-24 rounded-md" />
				</div>

				{/* Featured post card */}
				<div className="mb-14 overflow-hidden rounded-3xl border border-neutral-200/80 dark:border-neutral-800">
					<Skeleton className="aspect-[2.4/1] min-h-[220px] w-full rounded-none" />
					<div className="p-6 md:p-8 space-y-3">
						<Skeleton className="h-5 w-24 rounded-full" />
						<Skeleton className="h-8 w-3/4" />
						<Skeleton className="h-4 w-full max-w-lg" />
						<div className="flex items-center gap-3 pt-2">
							<Skeleton className="h-8 w-8 rounded-full" />
							<Skeleton className="h-4 w-28" />
							<Skeleton className="h-4 w-20" />
						</div>
					</div>
				</div>

				{/* Search bar */}
				<div className="mb-10">
					<Skeleton className="h-11 w-full max-w-md rounded-full" />
				</div>

				{/* Article card grid */}
				<div className="grid gap-8 md:grid-cols-2 lg:grid-cols-3">{articleCards}</div>
			</div>
		</div>
	);
}
|
||||
63
surfsense_web/app/(home)/changelog/loading.tsx
Normal file
63
surfsense_web/app/(home)/changelog/loading.tsx
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
||||
/**
 * Skeleton placeholder shaped like the changelog: a bordered header with
 * breadcrumb and title, followed by three timeline entries.
 */
export default function ChangelogLoading() {
	// Three timeline entries, each with a date/version column and a content column.
	const timelineEntries = Array.from({ length: 3 }, (_, idx) => (
		<div key={idx} className="relative flex flex-col md:flex-row gap-y-6 mb-10">
			{/* Left column: date and version badge */}
			<div className="md:w-48 flex-shrink-0">
				<Skeleton className="h-4 w-24 mb-3" />
				<Skeleton className="h-12 w-12 rounded-xl" />
			</div>

			{/* Right column: entry content */}
			<div className="flex-1 md:pl-8 relative pb-10">
				<div className="space-y-4">
					{/* Entry title */}
					<Skeleton className="h-7 w-2/3" />
					{/* Tag pills */}
					<div className="flex gap-2">
						<Skeleton className="h-6 w-16 rounded-full" />
						<Skeleton className="h-6 w-20 rounded-full" />
					</div>
					{/* Body paragraphs */}
					<div className="space-y-2">
						<Skeleton className="h-4 w-full" />
						<Skeleton className="h-4 w-full" />
						<Skeleton className="h-4 w-3/4" />
					</div>
					<div className="space-y-2">
						<Skeleton className="h-4 w-full" />
						<Skeleton className="h-4 w-5/6" />
					</div>
				</div>
			</div>
		</div>
	));

	return (
		<div className="min-h-screen relative pt-20">
			{/* Page header */}
			<div className="border-b border-border/50">
				<div className="max-w-5xl mx-auto relative">
					<div className="p-6 flex items-center justify-between">
						<div>
							{/* Breadcrumb trail */}
							<div className="flex items-center gap-2 mb-4">
								<Skeleton className="h-4 w-10" />
								<Skeleton className="h-4 w-3" />
								<Skeleton className="h-4 w-20" />
							</div>
							<Skeleton className="h-10 w-48 mb-2" />
							<Skeleton className="h-4 w-80" />
						</div>
					</div>
				</div>
			</div>

			{/* Timeline */}
			<div className="max-w-5xl mx-auto px-6 lg:px-10 pt-10 pb-20">
				<div className="relative">{timelineEntries}</div>
			</div>
		</div>
	);
}
|
||||
65
surfsense_web/app/(home)/free/[model_slug]/loading.tsx
Normal file
65
surfsense_web/app/(home)/free/[model_slug]/loading.tsx
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
||||
/**
 * Skeleton placeholder with two stacked regions: a chat area (header,
 * messages, input bar) and an SEO section (breadcrumb, intro, FAQ cards).
 */
export default function FreeModelLoading() {
	// Four FAQ card placeholders.
	const faqCards = Array.from({ length: 4 }, (_, idx) => (
		<div key={idx} className="rounded-lg border bg-card p-4 space-y-2">
			<Skeleton className="h-4 w-3/4" />
			<Skeleton className="h-3 w-full" />
			<Skeleton className="h-3 w-5/6" />
		</div>
	));

	return (
		<>
			{/* Chat area, sized to fill its container */}
			<div className="h-full flex flex-col">
				{/* Chat header */}
				<div className="flex items-center gap-3 border-b px-4 py-3">
					<Skeleton className="h-8 w-8 rounded-full" />
					<Skeleton className="h-5 w-40" />
				</div>

				{/* Messages: one right-aligned bubble, then a multi-line reply */}
				<div className="flex-1 flex flex-col justify-end gap-4 px-4 py-6">
					<div className="flex justify-end">
						<Skeleton className="h-10 w-56 rounded-2xl" />
					</div>
					<div className="space-y-2 max-w-lg">
						<Skeleton className="h-4 w-full" />
						<Skeleton className="h-4 w-4/5" />
						<Skeleton className="h-4 w-3/4" />
					</div>
				</div>

				{/* Input bar */}
				<div className="border-t px-4 py-3">
					<Skeleton className="h-12 w-full rounded-xl" />
				</div>
			</div>

			{/* SEO section */}
			<div className="border-t bg-background">
				<div className="container mx-auto px-4 py-10 max-w-3xl">
					{/* Breadcrumb trail */}
					<div className="flex items-center gap-2 mb-6">
						<Skeleton className="h-4 w-10" />
						<Skeleton className="h-4 w-3" />
						<Skeleton className="h-4 w-24" />
						<Skeleton className="h-4 w-3" />
						<Skeleton className="h-4 w-32" />
					</div>

					<Skeleton className="h-7 w-3/4 mb-2" />
					<Skeleton className="h-4 w-full mb-1" />
					<Skeleton className="h-4 w-2/3 mb-8" />

					<div className="my-8 h-px bg-border" />

					{/* FAQ heading and cards */}
					<Skeleton className="h-6 w-64 mb-4" />
					<div className="flex flex-col gap-3">{faqCards}</div>
				</div>
			</div>
		</>
	);
}
|
||||
60
surfsense_web/app/(home)/free/loading.tsx
Normal file
60
surfsense_web/app/(home)/free/loading.tsx
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
||||
/**
 * Skeleton placeholder with a breadcrumb, a centered hero section with
 * four pills, and a model table with a header row and eight body rows.
 */
export default function FreeChatLoading() {
	// Four pill placeholders under the hero copy.
	const heroPills = Array.from({ length: 4 }, (_, idx) => (
		<Skeleton key={idx} className="h-8 w-28 rounded-full" />
	));

	// Eight table-row placeholders.
	const tableRows = Array.from({ length: 8 }, (_, idx) => (
		<div key={idx} className="flex items-center gap-4 px-4 py-3 border-b last:border-0">
			<div className="flex-1 space-y-1.5">
				<Skeleton className="h-4 w-40" />
				<Skeleton className="h-3 w-24" />
			</div>
			<Skeleton className="h-4 w-24" />
			<Skeleton className="h-6 w-14 rounded-full" />
			<Skeleton className="h-8 w-20 rounded-md" />
		</div>
	));

	return (
		<div className="min-h-screen pt-20">
			<article className="container mx-auto px-4 pb-20">
				{/* Breadcrumb trail */}
				<div className="flex items-center gap-2 mb-8">
					<Skeleton className="h-4 w-10" />
					<Skeleton className="h-4 w-3" />
					<Skeleton className="h-4 w-24" />
				</div>

				{/* Hero: two-line heading, two-line blurb, pills */}
				<section className="mt-8 text-center max-w-3xl mx-auto space-y-4">
					<Skeleton className="h-12 w-3/4 mx-auto" />
					<Skeleton className="h-12 w-2/3 mx-auto" />
					<Skeleton className="h-5 w-full max-w-lg mx-auto" />
					<Skeleton className="h-5 w-4/5 max-w-lg mx-auto" />
					<div className="flex flex-wrap items-center justify-center gap-3 mt-6">
						{heroPills}
					</div>
				</section>

				<div className="my-12 max-w-4xl mx-auto h-px bg-border" />

				{/* Model table */}
				<section className="max-w-4xl mx-auto">
					<Skeleton className="h-7 w-64 mb-2" />
					<Skeleton className="h-4 w-80 mb-6" />

					<div className="overflow-hidden rounded-lg border">
						{/* Header row */}
						<div className="flex gap-4 px-4 py-3 bg-muted/50 border-b">
							<Skeleton className="h-4 w-[45%]" />
							<Skeleton className="h-4 w-24" />
							<Skeleton className="h-4 w-16" />
							<Skeleton className="h-4 w-20" />
						</div>

						{/* Body rows */}
						{tableRows}
					</div>
				</section>
			</article>
		</div>
	);
}
|
||||
File diff suppressed because it is too large
Load diff
55
surfsense_web/app/docs/[[...slug]]/loading.tsx
Normal file
55
surfsense_web/app/docs/[[...slug]]/loading.tsx
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
||||
/**
 * Skeleton placeholder shaped like a documentation page: title,
 * description, paragraphs, sub-headings, a code block and a bullet list.
 */
export default function DocsLoading() {
	// Four bullet-list rows: a small round marker plus one line of text.
	const listRows = Array.from({ length: 4 }, (_, idx) => (
		<div key={idx} className="flex items-start gap-3">
			<Skeleton className="mt-1 h-3 w-3 shrink-0 rounded-full" />
			<Skeleton className="h-4 w-full max-w-lg" />
		</div>
	));

	return (
		<div className="flex flex-1 flex-col gap-4 p-6 max-w-4xl mx-auto w-full">
			{/* Page title */}
			<Skeleton className="h-9 w-64" />

			{/* Page description */}
			<Skeleton className="h-5 w-full max-w-md" />

			<div className="mt-4 space-y-8">
				{/* First paragraph */}
				<div className="space-y-2">
					<Skeleton className="h-4 w-full" />
					<Skeleton className="h-4 w-full" />
					<Skeleton className="h-4 w-3/4" />
				</div>

				{/* Sub-heading */}
				<Skeleton className="h-7 w-48" />

				{/* Second paragraph */}
				<div className="space-y-2">
					<Skeleton className="h-4 w-full" />
					<Skeleton className="h-4 w-5/6" />
					<Skeleton className="h-4 w-full" />
					<Skeleton className="h-4 w-2/3" />
				</div>

				{/* Code block */}
				<Skeleton className="h-28 w-full rounded-lg" />

				{/* Sub-heading */}
				<Skeleton className="h-7 w-56" />

				{/* Bullet list */}
				<div className="space-y-3">{listRows}</div>

				{/* Third paragraph */}
				<div className="space-y-2">
					<Skeleton className="h-4 w-full" />
					<Skeleton className="h-4 w-4/5" />
				</div>
			</div>
		</div>
	);
}
|
||||
|
|
@ -10,21 +10,11 @@ import type { Document } from "@/contracts/types/document.types";
|
|||
export const mentionedDocumentsAtom = atom<Pick<Document, "id" | "title" | "document_type">[]>([]);
|
||||
|
||||
/**
|
||||
* Atom to store documents selected via the sidebar checkboxes / row clicks.
|
||||
* These are NOT inserted as chips – the composer shows a count badge instead.
|
||||
*/
|
||||
export const sidebarSelectedDocumentsAtom = atom<
|
||||
Pick<Document, "id" | "title" | "document_type">[]
|
||||
>([]);
|
||||
|
||||
/**
|
||||
* Derived read-only atom that merges @-mention chips and sidebar selections
|
||||
* into a single deduplicated set of document IDs for the backend.
|
||||
* Derived read-only atom that maps deduplicated mentioned docs
|
||||
* into backend payload fields.
|
||||
*/
|
||||
export const mentionedDocumentIdsAtom = atom((get) => {
|
||||
const chipDocs = get(mentionedDocumentsAtom);
|
||||
const sidebarDocs = get(sidebarSelectedDocumentsAtom);
|
||||
const allDocs = [...chipDocs, ...sidebarDocs];
|
||||
const allDocs = get(mentionedDocumentsAtom);
|
||||
const seen = new Set<string>();
|
||||
const deduped = allDocs.filter((d) => {
|
||||
const key = `${d.document_type}:${d.id}`;
|
||||
|
|
|
|||
|
|
@ -17,16 +17,12 @@ import {
|
|||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { getToolIcon } from "@/contracts/enums/toolIcons";
|
||||
import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons";
|
||||
import { type AgentAction, agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
|
||||
import { AppError } from "@/lib/error";
|
||||
import { formatRelativeDate } from "@/lib/format-date";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
function formatToolName(name: string): string {
|
||||
return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
|
||||
}
|
||||
|
||||
interface ActionLogItemProps {
|
||||
action: AgentAction;
|
||||
threadId: number;
|
||||
|
|
@ -43,7 +39,7 @@ export function ActionLogItem({ action, threadId, onRevertSuccess }: ActionLogIt
|
|||
const hasError = action.error !== null && action.error !== undefined;
|
||||
|
||||
const Icon = getToolIcon(action.tool_name);
|
||||
const displayName = formatToolName(action.tool_name);
|
||||
const displayName = getToolDisplayName(action.tool_name);
|
||||
|
||||
const argsPreview = action.args ? JSON.stringify(action.args, null, 2) : null;
|
||||
const truncatedArgs =
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery, useQueryClient } from "@tanstack/react-query";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useAtom, useAtomValue } from "jotai";
|
||||
import { Activity, RefreshCcw } from "lucide-react";
|
||||
import { useCallback, useMemo } from "react";
|
||||
import { useCallback } from "react";
|
||||
import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom";
|
||||
import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
|
|
@ -17,15 +17,12 @@ import {
|
|||
SheetTitle,
|
||||
} from "@/components/ui/sheet";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
|
||||
import {
|
||||
agentActionsQueryKey,
|
||||
useAgentActionsQuery,
|
||||
} from "@/hooks/use-agent-actions-query";
|
||||
import { ActionLogItem } from "./action-log-item";
|
||||
|
||||
const ACTION_LOG_PAGE_SIZE = 50;
|
||||
|
||||
function actionLogQueryKey(threadId: number) {
|
||||
return ["agent-actions", threadId] as const;
|
||||
}
|
||||
|
||||
function EmptyState() {
|
||||
return (
|
||||
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-6 text-center">
|
||||
|
|
@ -85,25 +82,17 @@ export function ActionLogSheet() {
|
|||
|
||||
const threadId = state.threadId;
|
||||
|
||||
const { data, isLoading, isFetching, isError, error, refetch } = useQuery({
|
||||
queryKey: threadId !== null ? actionLogQueryKey(threadId) : ["agent-actions", "none"],
|
||||
queryFn: () =>
|
||||
agentActionsApiService.listForThread(threadId as number, {
|
||||
page: 0,
|
||||
pageSize: ACTION_LOG_PAGE_SIZE,
|
||||
}),
|
||||
enabled: state.open && threadId !== null && actionLogEnabled,
|
||||
staleTime: 15 * 1000,
|
||||
});
|
||||
const { data, items, isLoading, isFetching, isError, error, refetch } = useAgentActionsQuery(
|
||||
threadId,
|
||||
{ enabled: state.open && actionLogEnabled }
|
||||
);
|
||||
|
||||
const handleRevertSuccess = useCallback(() => {
|
||||
if (threadId !== null) {
|
||||
queryClient.invalidateQueries({ queryKey: actionLogQueryKey(threadId) });
|
||||
queryClient.invalidateQueries({ queryKey: agentActionsQueryKey(threadId) });
|
||||
}
|
||||
}, [queryClient, threadId]);
|
||||
|
||||
const items = useMemo(() => data?.items ?? [], [data]);
|
||||
|
||||
return (
|
||||
<Sheet open={state.open} onOpenChange={(open) => setState((s) => ({ ...s, open }))}>
|
||||
<SheetContent
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ import {
|
|||
useAllCitationMetadata,
|
||||
} from "@/components/assistant-ui/citation-metadata-context";
|
||||
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
|
||||
import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part";
|
||||
import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button";
|
||||
import { useTokenUsage } from "@/components/assistant-ui/token-usage-context";
|
||||
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
|
|
@ -491,6 +493,7 @@ const AssistantMessageInner: FC = () => {
|
|||
<MessagePrimitive.Parts
|
||||
components={{
|
||||
Text: MarkdownText,
|
||||
Reasoning: ReasoningMessagePart,
|
||||
tools: {
|
||||
by_name: {
|
||||
generate_report: GenerateReportToolUI,
|
||||
|
|
@ -699,6 +702,13 @@ const AssistantActionBar: FC = () => {
|
|||
const isLast = useAuiState((s) => s.message.isLast);
|
||||
const aui = useAui();
|
||||
const api = useElectronAPI();
|
||||
// Surface the persisted ``chat_turn_id`` so the per-turn revert
|
||||
// affordance can scope to just this message's actions. Streamed
|
||||
// turns get their id once the assistant message is hydrated/finalised.
|
||||
const chatTurnId = useAuiState(({ message }) => {
|
||||
const meta = message?.metadata as { custom?: { chatTurnId?: string | null } } | undefined;
|
||||
return meta?.custom?.chatTurnId ?? null;
|
||||
});
|
||||
|
||||
const isQuickAssist = !!api?.replaceText && IS_QUICK_ASSIST_WINDOW;
|
||||
|
||||
|
|
@ -743,6 +753,9 @@ const AssistantActionBar: FC = () => {
|
|||
</TooltipIconButton>
|
||||
)}
|
||||
<MessageInfoDropdown />
|
||||
<div className="ml-auto">
|
||||
<RevertTurnButton chatTurnId={chatTurnId} />
|
||||
</div>
|
||||
</ActionBarPrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
|
|
|||
106
surfsense_web/components/assistant-ui/edit-message-dialog.tsx
Normal file
106
surfsense_web/components/assistant-ui/edit-message-dialog.tsx
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
"use client";
|
||||
|
||||
/**
|
||||
* Confirmation dialog shown when the user edits a message that has
|
||||
* reversible downstream actions. Three buttons:
|
||||
*
|
||||
* • "Revert all & resubmit" — POST regenerate with revert_actions=true
|
||||
* • "Continue without revert" — POST regenerate with revert_actions=false
|
||||
* • "Cancel" — abort the edit entirely
|
||||
*
|
||||
* The dialog is auto-skipped when zero reversible downstream actions
|
||||
* exist (the caller checks first via ``downstreamReversibleCount``).
|
||||
*/
|
||||
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogTitle,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
/** The three outcomes a user can pick from the edit-confirmation dialog. */
export type EditMessageDialogChoice = "revert" | "continue" | "cancel";

/** Props accepted by {@link EditMessageDialog}. */
export interface EditMessageDialogProps {
	/** Controlled open state, forwarded to the underlying `AlertDialog`. */
	open: boolean;
	/** Open-state change handler, forwarded to the underlying `AlertDialog`. */
	onOpenChange: (open: boolean) => void;
	/** Count of downstream actions that can be rolled back; shown in the copy and on the primary button. */
	downstreamReversibleCount: number;
	/** Count of downstream messages the edit drops from the thread; shown in the copy. */
	downstreamTotalCount: number;
	/** Called with the picked choice. May be async; all buttons stay disabled until it settles. */
	onChoose: (choice: EditMessageDialogChoice) => void | Promise<void>;
}

/**
 * Confirmation dialog offering "revert & resubmit", "continue without
 * reverting" and "cancel" for an edited message.
 *
 * While `onChoose` is pending, every button is disabled and the button
 * that was pressed shows a progress label.
 */
export function EditMessageDialog({
	open,
	onOpenChange,
	downstreamReversibleCount,
	downstreamTotalCount,
	onChoose,
}: EditMessageDialogProps) {
	// Which choice (if any) is currently being processed by `onChoose`.
	const [pendingChoice, setPendingChoice] = useState<EditMessageDialogChoice | null>(null);

	// Tracks whether this component is still mounted. Per the original
	// author's note, the parent closes the dialog before the promise
	// returned by `onChoose` settles, so the cleanup in `dispatchChoice`
	// can run after unmount; the ref lets it skip the state update then.
	const isMountedRef = useRef(true);
	useEffect(() => {
		isMountedRef.current = true;
		return () => {
			isMountedRef.current = false;
		};
	}, []);

	const isBusy = pendingChoice !== null;
	const messagePlural = downstreamTotalCount === 1 ? "" : "s";
	const actionPlural = downstreamReversibleCount === 1 ? "" : "s";

	const revertLabel =
		pendingChoice === "revert"
			? "Reverting & resubmitting…"
			: `Revert ${downstreamReversibleCount} action${actionPlural} & resubmit`;
	const continueLabel =
		pendingChoice === "continue" ? "Resubmitting…" : "Continue without reverting";

	// Marks the choice as in flight, runs the parent's handler, and clears
	// the in-flight marker afterwards (only if still mounted).
	const dispatchChoice = async (choice: EditMessageDialogChoice) => {
		setPendingChoice(choice);
		try {
			await onChoose(choice);
		} finally {
			if (isMountedRef.current) {
				setPendingChoice(null);
			}
		}
	};

	return (
		<AlertDialog open={open} onOpenChange={onOpenChange}>
			<AlertDialogContent>
				<AlertDialogHeader>
					<AlertDialogTitle>Edit this message?</AlertDialogTitle>
					<AlertDialogDescription>
						This edit drops {downstreamTotalCount} downstream message
						{messagePlural} from the thread. {downstreamReversibleCount}{" "}
						action
						{actionPlural} (e.g. file writes, connector changes) can
						be rolled back. Pick how to handle them before regenerating.
					</AlertDialogDescription>
				</AlertDialogHeader>

				<div className="grid gap-2">
					<Button variant="default" disabled={isBusy} onClick={() => dispatchChoice("revert")}>
						{revertLabel}
					</Button>
					<Button variant="outline" disabled={isBusy} onClick={() => dispatchChoice("continue")}>
						{continueLabel}
					</Button>
				</div>

				<AlertDialogFooter className="sm:justify-start">
					<AlertDialogCancel disabled={isBusy} onClick={() => dispatchChoice("cancel")}>
						Cancel
					</AlertDialogCancel>
				</AlertDialogFooter>
			</AlertDialogContent>
		</AlertDialog>
	);
}
|
||||
|
|
@ -11,25 +11,14 @@ import {
|
|||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
import { flushSync } from "react-dom";
|
||||
import { createRoot } from "react-dom/client";
|
||||
import { renderToStaticMarkup } from "react-dom/server";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import type { Document } from "@/contracts/types/document.types";
|
||||
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
// Render a React element to an HTML string on the client without pulling
|
||||
// `react-dom/server` into the bundle. `createRoot` + `flushSync` use the
|
||||
// same `react-dom` package React itself imports, so this adds zero new
|
||||
// runtime weight.
|
||||
function renderElementToHTML(element: ReactElement): string {
|
||||
const container = document.createElement("div");
|
||||
const root = createRoot(container);
|
||||
flushSync(() => {
|
||||
root.render(element);
|
||||
});
|
||||
const html = container.innerHTML;
|
||||
root.unmount();
|
||||
return html;
|
||||
return renderToStaticMarkup(element);
|
||||
}
|
||||
|
||||
export interface MentionedDocument {
|
||||
|
|
@ -44,7 +33,10 @@ export interface InlineMentionEditorRef {
|
|||
setText: (text: string) => void;
|
||||
getText: () => string;
|
||||
getMentionedDocuments: () => MentionedDocument[];
|
||||
insertDocumentChip: (doc: Pick<Document, "id" | "title" | "document_type">) => void;
|
||||
insertDocumentChip: (
|
||||
doc: Pick<Document, "id" | "title" | "document_type">,
|
||||
options?: { removeTriggerText?: boolean }
|
||||
) => void;
|
||||
removeDocumentChip: (docId: number, docType?: string) => void;
|
||||
setDocumentChipStatus: (
|
||||
docId: number,
|
||||
|
|
@ -66,7 +58,6 @@ interface InlineMentionEditorProps {
|
|||
onKeyDown?: (e: React.KeyboardEvent) => void;
|
||||
disabled?: boolean;
|
||||
className?: string;
|
||||
initialDocuments?: MentionedDocument[];
|
||||
initialText?: string;
|
||||
}
|
||||
|
||||
|
|
@ -118,7 +109,6 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
onKeyDown,
|
||||
disabled = false,
|
||||
className,
|
||||
initialDocuments = [],
|
||||
initialText,
|
||||
},
|
||||
ref
|
||||
|
|
@ -126,18 +116,49 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
const editorRef = useRef<HTMLDivElement>(null);
|
||||
const [isEmpty, setIsEmpty] = useState(true);
|
||||
const [mentionedDocs, setMentionedDocs] = useState<Map<string, MentionedDocument>>(
|
||||
() => new Map(initialDocuments.map((d) => [`${d.document_type ?? "UNKNOWN"}:${d.id}`, d]))
|
||||
() => new Map()
|
||||
);
|
||||
const isComposingRef = useRef(false);
|
||||
const lastSelectionRangeRef = useRef<Range | null>(null);
|
||||
const isRangeInsideEditor = useCallback((range: Range | null): range is Range => {
|
||||
if (!range || !editorRef.current) return false;
|
||||
return (
|
||||
editorRef.current.contains(range.startContainer) &&
|
||||
editorRef.current.contains(range.endContainer)
|
||||
);
|
||||
}, []);
|
||||
const isSelectionInsideEditor = useCallback(
|
||||
(selection: Selection | null): selection is Selection => {
|
||||
if (!selection || selection.rangeCount === 0 || !editorRef.current) return false;
|
||||
const range = selection.getRangeAt(0);
|
||||
return isRangeInsideEditor(range);
|
||||
},
|
||||
[isRangeInsideEditor]
|
||||
);
|
||||
|
||||
const rememberSelection = useCallback(() => {
|
||||
const selection = window.getSelection();
|
||||
if (!isSelectionInsideEditor(selection)) return;
|
||||
lastSelectionRangeRef.current = selection.getRangeAt(0).cloneRange();
|
||||
}, [isSelectionInsideEditor]);
|
||||
|
||||
const restoreRememberedSelection = useCallback((): Selection | null => {
|
||||
const selection = window.getSelection();
|
||||
if (!selection) return null;
|
||||
if (!isRangeInsideEditor(lastSelectionRangeRef.current)) return null;
|
||||
selection.removeAllRanges();
|
||||
selection.addRange(lastSelectionRangeRef.current.cloneRange());
|
||||
return selection;
|
||||
}, [isRangeInsideEditor]);
|
||||
|
||||
// Sync initial documents
|
||||
useEffect(() => {
|
||||
if (initialDocuments.length > 0) {
|
||||
setMentionedDocs(
|
||||
new Map(initialDocuments.map((d) => [`${d.document_type ?? "UNKNOWN"}:${d.id}`, d]))
|
||||
);
|
||||
}
|
||||
}, [initialDocuments]);
|
||||
const handleSelectionChange = () => {
|
||||
if (document.activeElement !== editorRef.current) return;
|
||||
rememberSelection();
|
||||
};
|
||||
document.addEventListener("selectionchange", handleSelectionChange);
|
||||
return () => document.removeEventListener("selectionchange", handleSelectionChange);
|
||||
}, [rememberSelection]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!initialText || !editorRef.current) return;
|
||||
|
|
@ -145,7 +166,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
editorRef.current.appendChild(document.createElement("br"));
|
||||
editorRef.current.appendChild(document.createElement("br"));
|
||||
setIsEmpty(false);
|
||||
onChange?.(initialText, Array.from(mentionedDocs.values()));
|
||||
onChange?.(initialText, []);
|
||||
editorRef.current.focus();
|
||||
const sel = window.getSelection();
|
||||
const range = document.createRange();
|
||||
|
|
@ -157,7 +178,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
range.insertNode(anchor);
|
||||
anchor.scrollIntoView({ block: "end" });
|
||||
anchor.remove();
|
||||
}, [initialText]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
}, [initialText, onChange]);
|
||||
|
||||
// Focus at the end of the editor
|
||||
const focusAtEnd = useCallback(() => {
|
||||
|
|
@ -211,6 +232,19 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
return Array.from(mentionedDocs.values());
|
||||
}, [mentionedDocs]);
|
||||
|
||||
const syncEditorState = useCallback(
|
||||
(docsOverride?: Map<string, MentionedDocument>) => {
|
||||
const docs = docsOverride
|
||||
? Array.from(docsOverride.values())
|
||||
: Array.from(mentionedDocs.values());
|
||||
const text = getText();
|
||||
const empty = text.length === 0 && docs.length === 0;
|
||||
setIsEmpty(empty);
|
||||
onChange?.(text, docs);
|
||||
},
|
||||
[getText, mentionedDocs, onChange]
|
||||
);
|
||||
|
||||
// Create a chip element for a document
|
||||
const createChipElement = useCallback(
|
||||
(doc: MentionedDocument): HTMLSpanElement => {
|
||||
|
|
@ -246,10 +280,11 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
chip.remove();
|
||||
const docKey = `${doc.document_type ?? "UNKNOWN"}:${doc.id}`;
|
||||
const docKey = getMentionDocKey(doc);
|
||||
setMentionedDocs((prev) => {
|
||||
const next = new Map(prev);
|
||||
next.delete(docKey);
|
||||
syncEditorState(next);
|
||||
return next;
|
||||
});
|
||||
onDocumentRemove?.(doc.id, doc.document_type);
|
||||
|
|
@ -294,13 +329,17 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
|
||||
return chip;
|
||||
},
|
||||
[focusAtEnd, onDocumentRemove]
|
||||
[focusAtEnd, onDocumentRemove, syncEditorState]
|
||||
);
|
||||
|
||||
// Insert a document chip at the current cursor position
|
||||
const insertDocumentChip = useCallback(
|
||||
(doc: Pick<Document, "id" | "title" | "document_type">) => {
|
||||
(
|
||||
doc: Pick<Document, "id" | "title" | "document_type">,
|
||||
options?: { removeTriggerText?: boolean }
|
||||
) => {
|
||||
if (!editorRef.current) return;
|
||||
const removeTriggerText = options?.removeTriggerText ?? true;
|
||||
|
||||
// Validate required fields for type safety
|
||||
if (typeof doc.id !== "number" || typeof doc.title !== "string") {
|
||||
|
|
@ -315,25 +354,51 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
};
|
||||
|
||||
// Add to mentioned docs map using unique key
|
||||
const docKey = `${doc.document_type ?? "UNKNOWN"}:${doc.id}`;
|
||||
const docKey = getMentionDocKey(doc);
|
||||
setMentionedDocs((prev) => new Map(prev).set(docKey, mentionDoc));
|
||||
const nextDocs = new Map(mentionedDocs);
|
||||
nextDocs.set(docKey, mentionDoc);
|
||||
|
||||
// Find and remove the @query text
|
||||
const selection = window.getSelection();
|
||||
if (!selection || selection.rangeCount === 0) {
|
||||
// No selection, just append
|
||||
const hasActiveSelection = isSelectionInsideEditor(selection);
|
||||
const resolvedSelection = hasActiveSelection ? selection : restoreRememberedSelection();
|
||||
if (
|
||||
!resolvedSelection ||
|
||||
resolvedSelection.rangeCount === 0 ||
|
||||
!isSelectionInsideEditor(resolvedSelection)
|
||||
) {
|
||||
// No valid in-editor selection: deterministically insert at end.
|
||||
editorRef.current.focus();
|
||||
const endSelection = window.getSelection();
|
||||
if (!endSelection) return;
|
||||
const endRange = document.createRange();
|
||||
endRange.selectNodeContents(editorRef.current);
|
||||
endRange.collapse(false);
|
||||
endSelection.removeAllRanges();
|
||||
endSelection.addRange(endRange);
|
||||
|
||||
const chip = createChipElement(mentionDoc);
|
||||
editorRef.current.appendChild(chip);
|
||||
editorRef.current.appendChild(document.createTextNode(" "));
|
||||
focusAtEnd();
|
||||
endRange.insertNode(chip);
|
||||
endRange.setStartAfter(chip);
|
||||
endRange.collapse(true);
|
||||
const space = document.createTextNode(" ");
|
||||
endRange.insertNode(space);
|
||||
endRange.setStartAfter(space);
|
||||
endRange.collapse(true);
|
||||
endSelection.removeAllRanges();
|
||||
endSelection.addRange(endRange);
|
||||
|
||||
syncEditorState(nextDocs);
|
||||
rememberSelection();
|
||||
return;
|
||||
}
|
||||
|
||||
// Find the @ symbol before the cursor and remove it along with any query text
|
||||
const range = selection.getRangeAt(0);
|
||||
const range = resolvedSelection.getRangeAt(0);
|
||||
const textNode = range.startContainer;
|
||||
|
||||
if (textNode.nodeType === Node.TEXT_NODE) {
|
||||
if (textNode.nodeType === Node.TEXT_NODE && removeTriggerText) {
|
||||
const text = textNode.textContent || "";
|
||||
const cursorPos = range.startOffset;
|
||||
|
||||
|
|
@ -369,8 +434,9 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
const newRange = document.createRange();
|
||||
newRange.setStart(afterNode, 1);
|
||||
newRange.collapse(true);
|
||||
selection.removeAllRanges();
|
||||
selection.addRange(newRange);
|
||||
resolvedSelection.removeAllRanges();
|
||||
resolvedSelection.addRange(newRange);
|
||||
rememberSelection();
|
||||
}
|
||||
} else {
|
||||
// No @ found, just insert at cursor
|
||||
|
|
@ -384,48 +450,56 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
range.insertNode(space);
|
||||
range.setStartAfter(space);
|
||||
range.collapse(true);
|
||||
resolvedSelection.removeAllRanges();
|
||||
resolvedSelection.addRange(range);
|
||||
rememberSelection();
|
||||
}
|
||||
} else {
|
||||
// Not in a text node, append to editor
|
||||
// Either explicit non-trigger insertion or no @query present.
|
||||
const chip = createChipElement(mentionDoc);
|
||||
editorRef.current.appendChild(chip);
|
||||
editorRef.current.appendChild(document.createTextNode(" "));
|
||||
focusAtEnd();
|
||||
range.insertNode(chip);
|
||||
range.setStartAfter(chip);
|
||||
range.collapse(true);
|
||||
const space = document.createTextNode(" ");
|
||||
range.insertNode(space);
|
||||
range.setStartAfter(space);
|
||||
range.collapse(true);
|
||||
resolvedSelection.removeAllRanges();
|
||||
resolvedSelection.addRange(range);
|
||||
rememberSelection();
|
||||
}
|
||||
|
||||
// Update empty state
|
||||
setIsEmpty(false);
|
||||
|
||||
// Trigger onChange
|
||||
if (onChange) {
|
||||
setTimeout(() => {
|
||||
onChange(getText(), getMentionedDocuments());
|
||||
}, 0);
|
||||
}
|
||||
syncEditorState(nextDocs);
|
||||
},
|
||||
[createChipElement, focusAtEnd, getText, getMentionedDocuments, onChange]
|
||||
[
|
||||
createChipElement,
|
||||
isSelectionInsideEditor,
|
||||
mentionedDocs,
|
||||
rememberSelection,
|
||||
restoreRememberedSelection,
|
||||
syncEditorState,
|
||||
]
|
||||
);
|
||||
|
||||
// Clear the editor
|
||||
const clear = useCallback(() => {
|
||||
if (editorRef.current) {
|
||||
editorRef.current.innerHTML = "";
|
||||
setIsEmpty(true);
|
||||
setMentionedDocs(new Map());
|
||||
const emptyDocs = new Map<string, MentionedDocument>();
|
||||
setMentionedDocs(emptyDocs);
|
||||
syncEditorState(emptyDocs);
|
||||
}
|
||||
}, []);
|
||||
}, [syncEditorState]);
|
||||
|
||||
// Replace editor content with plain text and place cursor at end
|
||||
const setText = useCallback(
|
||||
(text: string) => {
|
||||
if (!editorRef.current) return;
|
||||
editorRef.current.innerText = text;
|
||||
const empty = text.length === 0;
|
||||
setIsEmpty(empty);
|
||||
onChange?.(text, Array.from(mentionedDocs.values()));
|
||||
syncEditorState();
|
||||
focusAtEnd();
|
||||
},
|
||||
[focusAtEnd, onChange, mentionedDocs]
|
||||
[focusAtEnd, syncEditorState]
|
||||
);
|
||||
|
||||
const setDocumentChipStatus = useCallback(
|
||||
|
|
@ -473,7 +547,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
const removeDocumentChip = useCallback(
|
||||
(docId: number, docType?: string) => {
|
||||
if (!editorRef.current) return;
|
||||
const chipKey = `${docType ?? "UNKNOWN"}:${docId}`;
|
||||
const chipKey = getMentionDocKey({ id: docId, document_type: docType });
|
||||
const chips = editorRef.current.querySelectorAll<HTMLSpanElement>(
|
||||
`span[${CHIP_DATA_ATTR}="true"]`
|
||||
);
|
||||
|
|
@ -486,14 +560,11 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
setMentionedDocs((prev) => {
|
||||
const next = new Map(prev);
|
||||
next.delete(chipKey);
|
||||
syncEditorState(next);
|
||||
return next;
|
||||
});
|
||||
|
||||
const text = getText();
|
||||
const empty = text.length === 0 && mentionedDocs.size <= 1;
|
||||
setIsEmpty(empty);
|
||||
},
|
||||
[getText, mentionedDocs.size]
|
||||
[syncEditorState]
|
||||
);
|
||||
|
||||
// Expose methods via ref
|
||||
|
|
@ -594,6 +665,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
|
||||
// Notify parent of change
|
||||
onChange?.(text, Array.from(mentionedDocs.values()));
|
||||
rememberSelection();
|
||||
}, [
|
||||
getText,
|
||||
mentionedDocs,
|
||||
|
|
@ -602,6 +674,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
onMentionClose,
|
||||
onActionTrigger,
|
||||
onActionClose,
|
||||
rememberSelection,
|
||||
]);
|
||||
|
||||
// Handle keydown
|
||||
|
|
@ -639,10 +712,14 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
const chipDocType = getChipDocType(prevSibling);
|
||||
if (chipId !== null) {
|
||||
prevSibling.remove();
|
||||
const chipKey = `${chipDocType}:${chipId}`;
|
||||
const chipKey = getMentionDocKey({
|
||||
id: chipId,
|
||||
document_type: chipDocType,
|
||||
});
|
||||
setMentionedDocs((prev) => {
|
||||
const next = new Map(prev);
|
||||
next.delete(chipKey);
|
||||
syncEditorState(next);
|
||||
return next;
|
||||
});
|
||||
// Notify parent that a document was removed
|
||||
|
|
@ -676,10 +753,14 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
const chipDocType = getChipDocType(prevChild);
|
||||
if (chipId !== null) {
|
||||
prevChild.remove();
|
||||
const chipKey = `${chipDocType}:${chipId}`;
|
||||
const chipKey = getMentionDocKey({
|
||||
id: chipId,
|
||||
document_type: chipDocType,
|
||||
});
|
||||
setMentionedDocs((prev) => {
|
||||
const next = new Map(prev);
|
||||
next.delete(chipKey);
|
||||
syncEditorState(next);
|
||||
return next;
|
||||
});
|
||||
// Notify parent that a document was removed
|
||||
|
|
@ -691,7 +772,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
}
|
||||
}
|
||||
},
|
||||
[onKeyDown, onSubmit, onDocumentRemove, onMentionClose]
|
||||
[onKeyDown, onSubmit, onDocumentRemove, onMentionClose, syncEditorState]
|
||||
);
|
||||
|
||||
// Handle paste - strip formatting
|
||||
|
|
@ -713,7 +794,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
|
||||
return (
|
||||
<div className="relative w-full">
|
||||
{/** biome-ignore lint/a11y/useSemanticElements: <not important> */}
|
||||
{/* biome-ignore lint/a11y/noStaticElementInteractions: contenteditable mention editor requires a div for inline chips */}
|
||||
<div
|
||||
ref={editorRef}
|
||||
contentEditable={!disabled}
|
||||
|
|
@ -724,6 +805,9 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
onPaste={handlePaste}
|
||||
onCompositionStart={handleCompositionStart}
|
||||
onCompositionEnd={handleCompositionEnd}
|
||||
onKeyUp={rememberSelection}
|
||||
onMouseUp={rememberSelection}
|
||||
onBlur={rememberSelection}
|
||||
className={cn(
|
||||
"min-h-[24px] max-h-32 overflow-y-auto",
|
||||
"text-sm outline-none",
|
||||
|
|
@ -733,9 +817,6 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
)}
|
||||
style={{ wordBreak: "break-word" }}
|
||||
data-placeholder={placeholder}
|
||||
aria-label="Message input with inline mentions"
|
||||
role="textbox"
|
||||
aria-multiline="true"
|
||||
/>
|
||||
{/* Placeholder with fade animation on change */}
|
||||
{isEmpty && (
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
"use client";
|
||||
|
||||
import type { ReasoningMessagePartComponent } from "@assistant-ui/react";
|
||||
import { ChevronRightIcon } from "lucide-react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
/**
 * Renders the structured `reasoning` part emitted by the backend's
 * stream-parity v2 path (A1) as a collapsible "Thinking" section.
 *
 * Open/closed behaviour:
 *   - starts open only if the part is already `running` on first render;
 *   - is forced open whenever the status is `running`;
 *   - is forced closed when the status becomes `complete`;
 *   - any other status (e.g. `incomplete`) leaves the current state alone;
 *   - the user can toggle it at any time via the header button.
 *
 * Nothing is rendered when there is no text and the part is not running.
 *
 * The header button exposes its state through `aria-expanded` so
 * assistive technology can announce whether the reasoning is shown.
 *
 * Per the original author's note, the component is registered via the
 * `Reasoning` slot on `MessagePrimitive.Parts` in `assistant-message.tsx`
 * so it sits at the ordinal position of the reasoning block within the
 * message content array.
 */
export const ReasoningMessagePart: ReasoningMessagePartComponent = ({ text, status }) => {
	const statusType = status?.type;
	const isRunning = statusType === "running";
	const [isOpen, setIsOpen] = useState(() => isRunning);

	// Follow the streaming lifecycle: expand while running, collapse on completion.
	useEffect(() => {
		if (isRunning) {
			setIsOpen(true);
		} else if (statusType === "complete") {
			setIsOpen(false);
		}
	}, [isRunning, statusType]);

	const headerLabel = useMemo(() => {
		if (isRunning) return "Thinking";
		if (statusType === "incomplete") return "Thinking interrupted";
		return "Thought";
	}, [isRunning, statusType]);

	// An empty part is only worth showing while it is still streaming
	// (the shimmer header acts as the progress indicator).
	if (!text && !isRunning) return null;

	return (
		<div className="mx-auto w-full max-w-(--thread-max-width) px-2 py-2">
			<div className="rounded-lg">
				<button
					type="button"
					aria-expanded={isOpen}
					onClick={() => setIsOpen((prev) => !prev)}
					className={cn(
						"flex w-full items-center gap-1.5 text-left text-sm transition-colors",
						"text-muted-foreground hover:text-foreground"
					)}
				>
					{isRunning ? (
						<TextShimmerLoader text={headerLabel} size="sm" />
					) : (
						<span>{headerLabel}</span>
					)}
					<ChevronRightIcon
						className={cn("size-4 transition-transform duration-200", isOpen && "rotate-90")}
					/>
				</button>

				{/* Animating the grid row between 0fr and 1fr gives a height transition without measuring content. */}
				<div
					className={cn(
						"grid transition-[grid-template-rows] duration-300 ease-out",
						isOpen ? "grid-rows-[1fr]" : "grid-rows-[0fr]"
					)}
				>
					<div className="overflow-hidden">
						<div className="mt-2 border-l border-muted-foreground/30 pl-3 text-sm leading-relaxed text-muted-foreground whitespace-pre-wrap wrap-break-word">
							{text}
						</div>
					</div>
				</div>
			</div>
		</div>
	);
};
|
||||
213
surfsense_web/components/assistant-ui/revert-turn-button.tsx
Normal file
213
surfsense_web/components/assistant-ui/revert-turn-button.tsx
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
"use client";
|
||||
|
||||
/**
|
||||
* "Revert turn" button rendered at the bottom of every completed
|
||||
* assistant turn that has at least one reversible action.
|
||||
*
|
||||
* The button reads from the unified ``useAgentActionsQuery`` cache
|
||||
* (the SAME react-query cache the agent-actions sheet and the inline
|
||||
* Revert button consume) filtered by ``chat_turn_id``. It shows a
|
||||
* confirmation dialog summarising "N reversible / M total" and, on
|
||||
* confirm, calls ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
|
||||
*
|
||||
* The route returns a per-action result list and never collapses the
|
||||
* batch into a 4xx — so we render any failed/not_reversible rows inline
|
||||
* with their messages.
|
||||
*/
|
||||
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react";
|
||||
import { useMemo, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogTitle,
|
||||
AlertDialogTrigger,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
|
||||
import {
|
||||
applyRevertTurnResultsToCache,
|
||||
useAgentActionsQuery,
|
||||
} from "@/hooks/use-agent-actions-query";
|
||||
import {
|
||||
agentActionsApiService,
|
||||
type RevertTurnActionResult,
|
||||
} from "@/lib/apis/agent-actions-api.service";
|
||||
import { AppError } from "@/lib/error";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
/** Props for {@link RevertTurnButton}. */
interface RevertTurnButtonProps {
	/** Identifier of the chat turn whose actions should be reverted; nullish hides the button. */
	chatTurnId: string | null | undefined;
}

/**
 * Per-turn "Revert turn" control with a confirmation dialog and a
 * per-action results dialog.
 *
 * Renders nothing when there is no `chatTurnId`, no thread id in the
 * chat session state, or no action in the turn that is still revertible
 * (reversible, not already reverted, not itself a revert, and without
 * an error).
 *
 * On confirm it calls `agentActionsApiService.revertTurn`, patches the
 * agent-actions cache for rows reported as `reverted` or
 * `already_reverted`, and toasts a summary. When the response status is
 * not `"ok"`, the results dialog is opened so each row can be inspected.
 */
export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) {
	const session = useAtomValue(chatSessionStateAtom);
	const threadId = session?.threadId ?? null;
	const queryClient = useQueryClient();
	const { findByChatTurnId } = useAgentActionsQuery(threadId);

	const [isReverting, setIsReverting] = useState(false);
	const [confirmOpen, setConfirmOpen] = useState(false);
	const [resultsOpen, setResultsOpen] = useState(false);
	const [results, setResults] = useState<RevertTurnActionResult[]>([]);

	// Single pass over the turn's actions: revert rows are excluded from
	// both tallies; the rest count towards the total, and the subset that
	// can still be undone counts towards the reversible tally.
	const { reversibleCount, totalCount } = useMemo(() => {
		const isUnset = (value: unknown) => value === null || value === undefined;
		let reversible = 0;
		let total = 0;
		for (const action of findByChatTurnId(chatTurnId)) {
			if (action.is_revert_action) continue;
			total += 1;
			if (action.reversible && isUnset(action.reverted_by_action_id) && isUnset(action.error)) {
				reversible += 1;
			}
		}
		return { reversibleCount: reversible, totalCount: total };
	}, [findByChatTurnId, chatTurnId]);

	if (!chatTurnId || reversibleCount === 0 || !threadId) return null;

	const handleRevertTurn = async () => {
		setIsReverting(true);
		try {
			const response = await agentActionsApiService.revertTurn(threadId, chatTurnId);
			setResults(response.results);

			// Only rows that ended up reverted are reflected in the cache.
			const revertedEntries = response.results.flatMap((row) =>
				row.status === "reverted" || row.status === "already_reverted"
					? [{ id: row.action_id, newActionId: row.new_action_id ?? null }]
					: []
			);
			if (revertedEntries.length > 0) {
				applyRevertTurnResultsToCache(queryClient, threadId, revertedEntries);
			}

			if (response.status !== "ok") {
				// Failed, non-reversible and permission-denied rows all count
				// as "could not be undone"; per the original author's note,
				// skipped rows are batch artefacts and are left out of the tally.
				const failureCount =
					response.failed + response.not_reversible + (response.permission_denied ?? 0);
				toast.warning(
					`Reverted ${response.reverted} of ${response.total}. ${failureCount} could not be undone.`
				);
				setResultsOpen(true);
			} else {
				toast.success(
					response.reverted === 1 ? "Reverted 1 action." : `Reverted ${response.reverted} actions.`
				);
			}
		} catch (err) {
			// A 503 is swallowed without a toast.
			// NOTE(review): presumably that status is reported to the user elsewhere — confirm.
			const isServiceUnavailable = err instanceof AppError && err.status === 503;
			if (!isServiceUnavailable) {
				let message = "Failed to revert turn.";
				if (err instanceof AppError || err instanceof Error) {
					message = err.message;
				}
				toast.error(message);
			}
		} finally {
			setIsReverting(false);
			setConfirmOpen(false);
		}
	};

	const actionPlural = totalCount === 1 ? "" : "s";

	return (
		<>
			<AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}>
				<AlertDialogTrigger asChild>
					<Button
						size="sm"
						variant="ghost"
						className="text-muted-foreground hover:text-foreground gap-1.5"
						onClick={(event) => {
							// Stop the click from reaching ancestor click handlers.
							event.stopPropagation();
							setConfirmOpen(true);
						}}
					>
						<RotateCcw className="size-3.5" />
						<span>Revert turn</span>
						<span className="text-xs tabular-nums opacity-70">
							{reversibleCount}/{totalCount}
						</span>
					</Button>
				</AlertDialogTrigger>
				<AlertDialogContent>
					<AlertDialogHeader>
						<AlertDialogTitle>Revert this turn?</AlertDialogTitle>
						<AlertDialogDescription>
							This will undo {reversibleCount} of {totalCount} action
							{actionPlural} from this turn in reverse order. The chat history and
							any read-only actions are preserved. Some rows may not be reversible — partial success
							is normal.
						</AlertDialogDescription>
					</AlertDialogHeader>
					<AlertDialogFooter>
						<AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel>
						<AlertDialogAction
							disabled={isReverting}
							onClick={(event) => {
								// Cancel the click's default handling; the dialog is
								// closed explicitly in `handleRevertTurn`'s `finally`.
								event.preventDefault();
								handleRevertTurn();
							}}
						>
							{isReverting ? "Reverting…" : "Revert turn"}
						</AlertDialogAction>
					</AlertDialogFooter>
				</AlertDialogContent>
			</AlertDialog>

			<AlertDialog open={resultsOpen} onOpenChange={setResultsOpen}>
				<AlertDialogContent>
					<AlertDialogHeader>
						<AlertDialogTitle>Revert results</AlertDialogTitle>
						<AlertDialogDescription>
							Some actions could not be reverted. Review per-row outcomes below.
						</AlertDialogDescription>
					</AlertDialogHeader>
					<ul className="max-h-72 overflow-y-auto space-y-2 text-sm">
						{results.map((row) => (
							<RevertResultRow key={row.action_id} result={row} />
						))}
					</ul>
					<AlertDialogFooter>
						<AlertDialogAction onClick={() => setResultsOpen(false)}>Close</AlertDialogAction>
					</AlertDialogFooter>
				</AlertDialogContent>
			</AlertDialog>
		</>
	);
}
|
||||
|
||||
/**
 * One row of the "Revert results" list: a status icon, the tool's
 * display name, the raw status (underscores shown as spaces) and an
 * optional detail line.
 *
 * The detail line shows `result.error` when it is non-empty, otherwise
 * `result.message`; it is omitted when both are empty or missing.
 */
function RevertResultRow({ result }: { result: RevertTurnActionResult }) {
	const isOk = result.status === "reverted" || result.status === "already_reverted";
	const Icon = isOk ? CheckIcon : XCircleIcon;
	// Choose the detail text with the same truthiness rule that decides
	// whether the line is rendered, so an empty-string `error` can no
	// longer mask a present `message` and leave an empty paragraph.
	const detail = result.error || result.message;

	return (
		<li className="flex items-start gap-2 rounded-md border bg-muted/30 px-3 py-2">
			<Icon
				className={cn("size-4 mt-0.5 shrink-0", isOk ? "text-emerald-500" : "text-destructive")}
			/>
			<div className="min-w-0 flex-1">
				<p className="font-medium truncate">
					{getToolDisplayName(result.tool_name)}{" "}
					<span className="ml-1 text-xs text-muted-foreground">
						{result.status.replace(/_/g, " ")}
					</span>
				</p>
				{detail && <p className="text-xs text-muted-foreground mt-0.5">{detail}</p>}
			</div>
		</li>
	);
}
|
||||
27
surfsense_web/components/assistant-ui/step-separator.tsx
Normal file
27
surfsense_web/components/assistant-ui/step-separator.tsx
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"use client";
|
||||
|
||||
import { makeAssistantDataUI } from "@assistant-ui/react";
|
||||
|
||||
/**
 * Thin horizontal rule shown between model steps inside one assistant turn.
 *
 * The matching data part is pushed by `addStepSeparator` in
 * `streaming-state.ts` when a `start-step` SSE event arrives after the
 * message already holds non-step content.
 *
 * The backend currently emits a single `start-step` / `finish-step` pair per
 * turn, so most messages never contain a separator. The renderer is wired up
 * ahead of the planned per-model-step refactor (A2 follow-up) so that change
 * can land without touching the persistence path.
 */
const StepSeparatorDataRenderer = () => (
	<div className="mx-auto my-3 w-full max-w-(--thread-max-width) px-2">
		<div className="border-t border-border/60" />
	</div>
);
|
||||
|
||||
/**
 * Data-part UI built with `makeAssistantDataUI`: pairs the `"step-separator"`
 * part name with `StepSeparatorDataRenderer`.
 */
export const StepSeparatorDataUI = makeAssistantDataUI({
	render: StepSeparatorDataRenderer,
	name: "step-separator",
});
|
||||
|
|
@ -39,12 +39,10 @@ import {
|
|||
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
||||
import {
|
||||
mentionedDocumentsAtom,
|
||||
sidebarSelectedDocumentsAtom,
|
||||
} from "@/atoms/chat/mentioned-documents.atom";
|
||||
import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom";
|
||||
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
|
||||
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
|
||||
import { documentsSidebarOpenAtom } from "@/atoms/documents/ui.atoms";
|
||||
import { membersAtom } from "@/atoms/members/members-query.atoms";
|
||||
import {
|
||||
globalNewLLMConfigsAtom,
|
||||
|
|
@ -84,6 +82,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
|||
import {
|
||||
CONNECTOR_ICON_TO_TYPES,
|
||||
CONNECTOR_TOOL_ICON_PATHS,
|
||||
getToolDisplayName,
|
||||
getToolIcon,
|
||||
} from "@/contracts/enums/toolIcons";
|
||||
import type { Document } from "@/contracts/types/document.types";
|
||||
|
|
@ -91,6 +90,7 @@ import { useBatchCommentsPreload } from "@/hooks/use-comments";
|
|||
import { useCommentsSync } from "@/hooks/use-comments-sync";
|
||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
|
||||
import { captureDisplayToPngDataUrl } from "@/lib/chat/display-media-capture";
|
||||
import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
|
@ -364,12 +364,14 @@ const ClipboardChip: FC<{ text: string; onDismiss: () => void }> = ({ text, onDi
|
|||
const Composer: FC = () => {
|
||||
// Document mention state (atoms persist across component remounts)
|
||||
const [mentionedDocuments, setMentionedDocuments] = useAtom(mentionedDocumentsAtom);
|
||||
const setSidebarDocs = useSetAtom(sidebarSelectedDocumentsAtom);
|
||||
const [showDocumentPopover, setShowDocumentPopover] = useState(false);
|
||||
const [showPromptPicker, setShowPromptPicker] = useState(false);
|
||||
const [mentionQuery, setMentionQuery] = useState("");
|
||||
const [actionQuery, setActionQuery] = useState("");
|
||||
const editorRef = useRef<InlineMentionEditorRef>(null);
|
||||
const prevMentionedDocsRef = useRef<
|
||||
Map<string, Pick<Document, "id" | "title" | "document_type">>
|
||||
>(new Map());
|
||||
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
|
||||
const promptPickerRef = useRef<PromptPickerRef>(null);
|
||||
const viewportRef = useRef<Element | null>(null);
|
||||
|
|
@ -605,7 +607,6 @@ const Composer: FC = () => {
|
|||
aui.composer().send();
|
||||
editorRef.current?.clear();
|
||||
setMentionedDocuments([]);
|
||||
setSidebarDocs([]);
|
||||
|
||||
// With turnAnchor="top", ViewportSlack adds min-height to the last
|
||||
// assistant message so that scrolling-to-bottom actually positions the
|
||||
|
|
@ -652,43 +653,71 @@ const Composer: FC = () => {
|
|||
clipboardInitialText,
|
||||
aui,
|
||||
setMentionedDocuments,
|
||||
setSidebarDocs,
|
||||
threadViewportStore,
|
||||
]);
|
||||
|
||||
const handleDocumentRemove = useCallback(
|
||||
(docId: number, docType?: string) => {
|
||||
setMentionedDocuments((prev) =>
|
||||
prev.filter((doc) => !(doc.id === docId && doc.document_type === docType))
|
||||
);
|
||||
setMentionedDocuments((prev) => {
|
||||
if (!docType) {
|
||||
// Defensive fallback: keep UI in sync even when chip type is unavailable.
|
||||
return prev.filter((doc) => doc.id !== docId);
|
||||
}
|
||||
const removedKey = getMentionDocKey({ id: docId, document_type: docType });
|
||||
return prev.filter((doc) => getMentionDocKey(doc) !== removedKey);
|
||||
});
|
||||
},
|
||||
[setMentionedDocuments]
|
||||
);
|
||||
|
||||
const handleDocumentsMention = useCallback(
|
||||
(documents: Pick<Document, "id" | "title" | "document_type">[]) => {
|
||||
const existingKeys = new Set(mentionedDocuments.map((d) => `${d.document_type}:${d.id}`));
|
||||
const newDocs = documents.filter(
|
||||
(doc) => !existingKeys.has(`${doc.document_type}:${doc.id}`)
|
||||
);
|
||||
const editorMentionedDocs = editorRef.current?.getMentionedDocuments() ?? [];
|
||||
const editorDocKeys = new Set(editorMentionedDocs.map((doc) => getMentionDocKey(doc)));
|
||||
|
||||
for (const doc of newDocs) {
|
||||
for (const doc of documents) {
|
||||
const key = getMentionDocKey(doc);
|
||||
if (editorDocKeys.has(key)) continue;
|
||||
editorRef.current?.insertDocumentChip(doc);
|
||||
}
|
||||
|
||||
setMentionedDocuments((prev) => {
|
||||
const existingKeySet = new Set(prev.map((d) => `${d.document_type}:${d.id}`));
|
||||
const uniqueNewDocs = documents.filter(
|
||||
(doc) => !existingKeySet.has(`${doc.document_type}:${doc.id}`)
|
||||
);
|
||||
const existingKeySet = new Set(prev.map((d) => getMentionDocKey(d)));
|
||||
const uniqueNewDocs = documents.filter((doc) => !existingKeySet.has(getMentionDocKey(doc)));
|
||||
return [...prev, ...uniqueNewDocs];
|
||||
});
|
||||
|
||||
setMentionQuery("");
|
||||
},
|
||||
[mentionedDocuments, setMentionedDocuments]
|
||||
[setMentionedDocuments]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const editor = editorRef.current;
|
||||
const nextDocsMap = new Map(mentionedDocuments.map((doc) => [getMentionDocKey(doc), doc]));
|
||||
const prevDocsMap = prevMentionedDocsRef.current;
|
||||
|
||||
if (!editor) {
|
||||
prevMentionedDocsRef.current = nextDocsMap;
|
||||
return;
|
||||
}
|
||||
|
||||
const editorKeys = new Set(editor.getMentionedDocuments().map(getMentionDocKey));
|
||||
|
||||
for (const [key, doc] of nextDocsMap) {
|
||||
if (prevDocsMap.has(key) || editorKeys.has(key)) continue;
|
||||
editor.insertDocumentChip(doc, { removeTriggerText: false });
|
||||
}
|
||||
|
||||
for (const [key, doc] of prevDocsMap) {
|
||||
if (!nextDocsMap.has(key)) {
|
||||
editor.removeDocumentChip(doc.id, doc.document_type);
|
||||
}
|
||||
}
|
||||
|
||||
prevMentionedDocsRef.current = nextDocsMap;
|
||||
}, [mentionedDocuments]);
|
||||
|
||||
return (
|
||||
<ComposerPrimitive.Root className="aui-composer-root relative flex w-full flex-col gap-2">
|
||||
<ChatSessionStatus
|
||||
|
|
@ -767,8 +796,6 @@ interface ComposerActionProps {
|
|||
|
||||
const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false }) => {
|
||||
const mentionedDocuments = useAtomValue(mentionedDocumentsAtom);
|
||||
const sidebarDocs = useAtomValue(sidebarSelectedDocumentsAtom);
|
||||
const setDocumentsSidebarOpen = useSetAtom(documentsSidebarOpenAtom);
|
||||
const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom);
|
||||
const [toolsPopoverOpen, setToolsPopoverOpen] = useState(false);
|
||||
const isDesktop = useMediaQuery("(min-width: 640px)");
|
||||
|
|
@ -1226,15 +1253,6 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
</AnimatePresence>
|
||||
</button>
|
||||
)}
|
||||
{sidebarDocs.length > 0 && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setDocumentsSidebarOpen(true)}
|
||||
className="rounded-full border border-border/60 bg-accent/50 px-2.5 py-1 text-xs font-medium text-foreground/80 transition-colors hover:bg-accent"
|
||||
>
|
||||
{sidebarDocs.length} {sidebarDocs.length === 1 ? "source" : "sources"} selected
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
{!hasModelConfigured && (
|
||||
<div className="flex items-center gap-1.5 text-amber-600 dark:text-amber-400 text-xs">
|
||||
|
|
@ -1300,12 +1318,14 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
);
|
||||
};
|
||||
|
||||
/**
 * Friendly tool name for display in the chat UI.
 *
 * Delegates to ``getToolDisplayName`` from ``contracts/enums/toolIcons`` so
 * unix-style identifiers (``rm``, ``ls``, ``grep`` …) and snake_cased function
 * names render as plain English (e.g. "Delete file", "List files",
 * "Search in files").
 *
 * The earlier split/upper-case/join implementation is removed: it returned
 * first, which left the delegation below it unreachable.
 *
 * @param name - Raw tool identifier as emitted by the agent.
 * @returns The human-readable label for that tool.
 */
function formatToolName(name: string): string {
	return getToolDisplayName(name);
}
|
||||
|
||||
interface ToolGroup {
|
||||
|
|
|
|||
|
|
@ -1,30 +1,288 @@
|
|||
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
|
||||
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XCircleIcon } from "lucide-react";
|
||||
import { useMemo, useState } from "react";
|
||||
import {
|
||||
type ToolCallMessagePartComponent,
|
||||
useAuiState,
|
||||
} from "@assistant-ui/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
||||
import {
|
||||
DoomLoopApprovalToolUI,
|
||||
isDoomLoopInterrupt,
|
||||
} from "@/components/tool-ui/doom-loop-approval";
|
||||
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
|
||||
import { getToolIcon } from "@/contracts/enums/toolIcons";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogTitle,
|
||||
AlertDialogTrigger,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Card } from "@/components/ui/card";
|
||||
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
|
||||
import {
|
||||
markActionRevertedInCache,
|
||||
useAgentActionsQuery,
|
||||
} from "@/hooks/use-agent-actions-query";
|
||||
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
|
||||
import { AppError } from "@/lib/error";
|
||||
import { isInterruptResult } from "@/lib/hitl";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
function formatToolName(name: string): string {
|
||||
return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
|
||||
/**
 * Inline "Revert" control shown on a tool card when the matching
 * ``AgentActionLog`` row is reversible and has not been reverted yet.
 *
 * Data comes from the unified ``useAgentActionsQuery`` cache — the same
 * react-query cache the agent-actions sheet consumes. SSE events
 * (``data-action-log`` / ``data-action-log-updated``) and
 * ``POST /threads/{id}/revert/{id}`` responses both flow through that cache
 * via ``setQueryData`` helpers, so card and sheet stay in lockstep on page
 * reload, navigation, live stream, post-stream reversibility flips and
 * explicit revert clicks.
 *
 * The action row is resolved in three tiers, first hit wins:
 *   1. ``a.tool_call_id === toolCallId`` — direct hit in parity_v2 when the
 *      model streamed ``tool_call_chunks`` so the card's synthetic id IS the
 *      LangChain id.
 *   2. ``a.tool_call_id === langchainToolCallId`` — legacy mode (or parity_v2
 *      with provider-side chunk emission) where the card's synthetic id is
 *      ``call_<run_id>`` and the LangChain id is backfilled onto the part by
 *      ``tool-output-available``.
 *   3. ``(chat_turn_id, tool_name, position-within-turn)`` — fallback for
 *      cards whose synthetic id is ``call_<run_id>`` AND whose
 *      ``langchainToolCallId`` was never backfilled (single non-chunked
 *      tool_call payload on a stream that pre-dates the backfill, e.g. older
 *      threads). The parent message's ``chatTurnId`` and ``content`` are read
 *      through ``useAuiState`` so the card's position among same-named tool
 *      calls can be matched against the action_log rows the server returned
 *      in ``created_at`` order.
 *
 * Renders ``null`` when no eligible action is found or no thread is active.
 */
function ToolCardRevertButton({
	toolCallId,
	toolName,
	langchainToolCallId,
}: {
	toolCallId: string;
	toolName: string;
	langchainToolCallId?: string;
}) {
	const session = useAtomValue(chatSessionStateAtom);
	const threadId = session?.threadId ?? null;
	const queryClient = useQueryClient();
	const { findByToolCallId, findByChatTurnAndTool } = useAgentActionsQuery(threadId);

	// Both selectors below return PRIMITIVES on purpose. ``useAuiState``
	// re-renders whenever the returned slice's identity changes, and the
	// runtime rebuilds the parts array on every token — returning
	// ``message?.content`` directly would re-render this card on every
	// text-delta. A string / number lets the ``Object.is`` check
	// short-circuit while only ``text`` / ``reasoning`` parts are mutating.
	// (See Vercel React rule ``rerender-defer-reads``.)
	const chatTurnId = useAuiState(({ message }) => {
		const metadata = message?.metadata as { custom?: { chatTurnId?: string } } | undefined;
		return metadata?.custom?.chatTurnId ?? null;
	});
	const positionInTurn = useAuiState(({ message }) => {
		const parts = message?.content;
		if (!Array.isArray(parts)) return -1;
		// Zero-based index of this card among tool-call parts that share
		// its ``toolName``; -1 when the card is not in the message.
		let sameToolSeen = 0;
		for (const part of parts) {
			if (!part || typeof part !== "object") continue;
			if ((part as { type?: string }).type !== "tool-call") continue;
			if ((part as { toolName?: string }).toolName !== toolName) continue;
			if ((part as { toolCallId?: string }).toolCallId === toolCallId) return sameToolSeen;
			sameToolSeen += 1;
		}
		return -1;
	});

	const action = useMemo(() => {
		// Tiers 1 + 2: direct id lookup — synthetic id first, then the
		// backfilled LangChain id.
		const byId = findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId);
		if (byId) return byId;
		// Tier 3: position-within-turn fallback, only possible when both
		// the turn id and this card's position are known.
		if (!chatTurnId || positionInTurn < 0) return null;
		return findByChatTurnAndTool(chatTurnId, toolName)[positionInTurn] ?? null;
	}, [
		findByToolCallId,
		findByChatTurnAndTool,
		toolCallId,
		langchainToolCallId,
		chatTurnId,
		toolName,
		positionInTurn,
	]);

	const [isReverting, setIsReverting] = useState(false);
	const [confirmOpen, setConfirmOpen] = useState(false);

	// Nothing to render without a resolved action and an active thread.
	if (!action || !threadId) return null;

	// Hide the control for rows that cannot (or must not) be reverted:
	// irreversible, already reverted, revert entries themselves, or errored.
	const wasReverted =
		action.reverted_by_action_id !== null && action.reverted_by_action_id !== undefined;
	const hasError = action.error !== null && action.error !== undefined;
	if (!action.reversible || wasReverted || action.is_revert_action || hasError) return null;

	const handleRevert = async () => {
		setIsReverting(true);
		try {
			const response = await agentActionsApiService.revert(threadId, action.id);
			markActionRevertedInCache(queryClient, threadId, action.id, response.new_action_id ?? null);
			toast.success(response.message || "Action reverted.");
		} catch (err) {
			// 503 means revert is gated off on this deployment — stay
			// silent rather than nagging the user. Every other failure is
			// surfaced as a toast so the operator can investigate.
			if (err instanceof AppError && err.status === 503) return;
			const description =
				err instanceof AppError || err instanceof Error
					? err.message
					: "Failed to revert action.";
			toast.error(description);
		} finally {
			setIsReverting(false);
			setConfirmOpen(false);
		}
	};

	// Spinner's typed props don't accept ``data-icon`` and it renders an
	// <output>, not an <svg>, so Button's auto-sizing rule doesn't apply.
	// Bare spinner + Button's gap handle layout.
	const triggerIcon = isReverting ? <Spinner size="xs" /> : <RotateCcw data-icon="inline-start" />;

	return (
		<AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}>
			<AlertDialogTrigger asChild>
				<Button
					size="sm"
					variant="outline"
					className="gap-1.5"
					disabled={isReverting}
					onClick={(event) => {
						// Keep the click from reaching ancestor handlers.
						event.stopPropagation();
						setConfirmOpen(true);
					}}
				>
					{triggerIcon}
					Revert
				</Button>
			</AlertDialogTrigger>
			<AlertDialogContent>
				<AlertDialogHeader>
					<AlertDialogTitle>Revert this action?</AlertDialogTitle>
					<AlertDialogDescription>
						This will undo{" "}
						<span className="font-medium">{getToolDisplayName(action.tool_name)}</span> and add a
						new entry to the history. Your chat is preserved — only the changes the agent made to
						your knowledge base or connected apps will be rolled back where possible.
					</AlertDialogDescription>
				</AlertDialogHeader>
				<AlertDialogFooter>
					<AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel>
					<AlertDialogAction
						className="gap-1.5"
						disabled={isReverting}
						onClick={(event) => {
							// Suppress the default action; ``handleRevert``
							// closes the dialog itself once the request settles.
							event.preventDefault();
							handleRevert();
						}}
					>
						{isReverting && <Spinner size="xs" />}
						Revert
					</AlertDialogAction>
				</AlertDialogFooter>
			</AlertDialogContent>
		</AlertDialog>
	);
}
|
||||
|
||||
const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
|
||||
toolName,
|
||||
argsText,
|
||||
result,
|
||||
status,
|
||||
}) => {
|
||||
const [isExpanded, setIsExpanded] = useState(false);
|
||||
/**
|
||||
* Compact tool-call card.
|
||||
*
|
||||
* shadcn composition note: we intentionally use ``Card`` as a visual
|
||||
* frame WITHOUT ``CardHeader / CardContent``. The full composition's
|
||||
* ``p-6`` padding doesn't fit a compact collapsible header that IS the
|
||||
* trigger; using ``Card`` alone preserves the rounded border, shadow,
|
||||
* and ``bg-card`` token (semantic colors) without forcing a layout
|
||||
* that doesn't fit. All status colors use semantic tokens — no manual
|
||||
* dark-mode overrides, no raw hex.
|
||||
*/
|
||||
const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
||||
const { toolCallId, toolName, argsText, result, status } = props;
|
||||
// ``langchainToolCallId`` is a SurfSense-specific extension the
|
||||
// streaming pipeline attaches to the tool-call content part so
|
||||
// the Revert button can resolve its ``AgentActionLog`` row even
|
||||
// when only the LC id is known. assistant-ui's
|
||||
// ``ToolCallMessagePartProps`` doesn't list it, but the runtime
|
||||
// spreads ``{...part}`` so the prop reaches us at runtime.
|
||||
const langchainToolCallId = (props as { langchainToolCallId?: string }).langchainToolCallId;
|
||||
|
||||
const isCancelled = status?.type === "incomplete" && status.reason === "cancelled";
|
||||
const isError = status?.type === "incomplete" && status.reason === "error";
|
||||
const isRunning = status?.type === "running" || status?.type === "requires-action";
|
||||
|
||||
/*
|
||||
Per-card expansion state. Initial value is ``isRunning`` so a
|
||||
card streaming in mounts already-expanded (no flash of
|
||||
collapsed → expanded on first paint), while a card loaded from
|
||||
history (status="complete") mounts collapsed. The useEffect
|
||||
below keeps this in lockstep with this card's own ``isRunning``
|
||||
when it transitions: false → true auto-expands (e.g. a tool
|
||||
that re-runs after edit), true → false auto-collapses once the
|
||||
tool finishes. Because the dep is per-card ``isRunning`` and
|
||||
not the chat-level streaming flag, sibling cards on the same
|
||||
assistant turn each manage their own expansion independently.
|
||||
Once ``isRunning`` is false the user controls expansion via
|
||||
``onOpenChange``.
|
||||
*/
|
||||
const [isExpanded, setIsExpanded] = useState(isRunning);
|
||||
useEffect(() => {
|
||||
setIsExpanded(isRunning);
|
||||
}, [isRunning]);
|
||||
const errorData = status?.type === "incomplete" ? status.error : undefined;
|
||||
const serializedError = useMemo(
|
||||
() => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null),
|
||||
|
|
@ -50,105 +308,207 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
|
|||
: serializedError
|
||||
: null;
|
||||
|
||||
const Icon = getToolIcon(toolName);
|
||||
const displayName = formatToolName(toolName);
|
||||
const displayName = getToolDisplayName(toolName);
|
||||
const subtitle = errorReason ?? cancelledReason;
|
||||
|
||||
return (
|
||||
<div
|
||||
<Card
|
||||
className={cn(
|
||||
"my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none",
|
||||
"my-4 max-w-lg overflow-hidden",
|
||||
isCancelled && "opacity-60",
|
||||
isError && "border-destructive/20 bg-destructive/5"
|
||||
isError && "border-destructive/30"
|
||||
)}
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setIsExpanded((prev) => !prev)}
|
||||
className="flex w-full items-center gap-3 px-5 py-4 text-left transition-colors hover:bg-muted/50 focus:outline-none focus-visible:outline-none"
|
||||
{/*
|
||||
``group`` lets the chevron (rendered as a sibling of the
|
||||
main trigger button) read the Collapsible Root's
|
||||
``data-[state=open]`` for rotation. The Collapsible is
|
||||
fully controlled via ``isExpanded`` — the useEffect
|
||||
above syncs it to ``isRunning`` so the card auto-opens
|
||||
while a tool streams in and auto-collapses once it
|
||||
finishes. We deliberately DON'T pass ``disabled`` so
|
||||
both triggers stay clickable; ``onOpenChange`` is wired
|
||||
to a setter that no-ops while ``isRunning`` (see
|
||||
``handleOpenChange`` below) which keeps the card pinned
|
||||
open mid-stream without losing keyboard / pointer
|
||||
affordance the moment streaming ends.
|
||||
*/}
|
||||
<Collapsible
|
||||
className="group"
|
||||
open={isExpanded}
|
||||
onOpenChange={(next) => {
|
||||
// Block manual collapse while the tool is still
|
||||
// streaming — otherwise a stray click on either
|
||||
// trigger would close the card and hide the live
|
||||
// ``argsText`` panel mid-run. After streaming the
|
||||
// user has full control again.
|
||||
if (isRunning) return;
|
||||
setIsExpanded(next);
|
||||
}}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex size-8 shrink-0 items-center justify-center rounded-lg",
|
||||
isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10"
|
||||
)}
|
||||
>
|
||||
{isError ? (
|
||||
<XCircleIcon className="size-4 text-destructive" />
|
||||
) : isCancelled ? (
|
||||
<XCircleIcon className="size-4 text-muted-foreground" />
|
||||
) : isRunning ? (
|
||||
<Icon className="size-4 text-primary animate-pulse" />
|
||||
) : (
|
||||
<CheckIcon className="size-4 text-primary" />
|
||||
)}
|
||||
</div>
|
||||
{/*
|
||||
Header row: main trigger on the left (icon + title
|
||||
col), Revert + chevron-trigger on the right as
|
||||
siblings of the main trigger. The chevron is wrapped
|
||||
in its OWN ``CollapsibleTrigger`` (Radix supports
|
||||
multiple triggers per Root) so clicking the chevron
|
||||
toggles the same state as clicking the title row.
|
||||
The Revert button stays a separate AlertDialog
|
||||
trigger and stops propagation in its onClick so it
|
||||
doesn't toggle the collapsible while opening the
|
||||
confirm dialog. Keeping these as flat siblings —
|
||||
rather than nesting Revert / chevron inside the
|
||||
title trigger — avoids invalid HTML
|
||||
(button-in-button) and lets the Revert button
|
||||
render in BOTH the collapsed and expanded states.
|
||||
*/}
|
||||
<div className="flex items-stretch transition-colors hover:bg-muted/50">
|
||||
<CollapsibleTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
className={cn(
|
||||
"flex flex-1 min-w-0 items-center gap-3 py-4 pl-5 pr-2 text-left",
|
||||
// Inset ring — Card's ``overflow-hidden`` would
|
||||
// clip an ``offset-2`` ring; ``ring-inset``
|
||||
// paints inside the button box.
|
||||
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
|
||||
"disabled:cursor-default"
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex size-8 shrink-0 items-center justify-center rounded-lg",
|
||||
isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10"
|
||||
)}
|
||||
>
|
||||
{isError ? (
|
||||
<XCircleIcon className="size-4 text-destructive" />
|
||||
) : isCancelled ? (
|
||||
<XCircleIcon className="size-4 text-muted-foreground" />
|
||||
) : isRunning ? (
|
||||
<Spinner size="sm" className="text-primary" />
|
||||
) : (
|
||||
<CheckIcon className="size-4 text-primary" />
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex-1 min-w-0">
|
||||
<p
|
||||
className={cn(
|
||||
"text-sm font-semibold",
|
||||
isError
|
||||
? "text-destructive"
|
||||
: isCancelled
|
||||
? "text-muted-foreground line-through"
|
||||
: "text-foreground"
|
||||
)}
|
||||
>
|
||||
{isRunning
|
||||
? displayName
|
||||
: isCancelled
|
||||
? `Cancelled: ${displayName}`
|
||||
: isError
|
||||
? `Failed: ${displayName}`
|
||||
: displayName}
|
||||
</p>
|
||||
{isRunning && <p className="text-xs text-muted-foreground mt-0.5">Running...</p>}
|
||||
{cancelledReason && (
|
||||
<p className="text-xs text-muted-foreground mt-0.5 truncate">{cancelledReason}</p>
|
||||
)}
|
||||
{errorReason && (
|
||||
<p className="text-xs text-destructive/80 mt-0.5 truncate">{errorReason}</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex flex-1 min-w-0 flex-col gap-0.5">
|
||||
<div className="flex items-center gap-2">
|
||||
<p
|
||||
className={cn(
|
||||
"text-sm font-semibold truncate",
|
||||
isCancelled && "text-muted-foreground line-through",
|
||||
isError && "text-destructive"
|
||||
)}
|
||||
>
|
||||
{displayName}
|
||||
</p>
|
||||
{isRunning && <Badge variant="secondary">Running</Badge>}
|
||||
{isError && <Badge variant="destructive">Failed</Badge>}
|
||||
{isCancelled && <Badge variant="outline">Cancelled</Badge>}
|
||||
</div>
|
||||
{subtitle && (
|
||||
<p
|
||||
className={cn(
|
||||
"text-xs truncate",
|
||||
isError ? "text-destructive/80" : "text-muted-foreground"
|
||||
)}
|
||||
>
|
||||
{subtitle}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</button>
|
||||
</CollapsibleTrigger>
|
||||
|
||||
{!isRunning && (
|
||||
<div className="shrink-0 text-muted-foreground">
|
||||
{isExpanded ? (
|
||||
<ChevronDownIcon className="size-4" />
|
||||
) : (
|
||||
<ChevronUpIcon className="size-4" />
|
||||
)}
|
||||
{/*
|
||||
Right-side controls. The Revert button is
|
||||
visible whenever the matching action is
|
||||
reversible — including the collapsed state —
|
||||
but ``ToolCardRevertButton`` itself returns
|
||||
``null`` while a tool is still running because
|
||||
no action-log row exists yet, so it doesn't
|
||||
need an explicit ``isRunning`` gate here.
|
||||
*/}
|
||||
<div className="flex shrink-0 items-center gap-2 pl-2 pr-5">
|
||||
<ToolCardRevertButton
|
||||
toolCallId={toolCallId}
|
||||
toolName={toolName}
|
||||
langchainToolCallId={langchainToolCallId}
|
||||
/>
|
||||
<CollapsibleTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
aria-label={isExpanded ? "Collapse details" : "Expand details"}
|
||||
className={cn(
|
||||
"flex size-7 shrink-0 items-center justify-center rounded-md",
|
||||
"text-muted-foreground hover:bg-muted hover:text-foreground",
|
||||
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
|
||||
"disabled:cursor-default"
|
||||
)}
|
||||
>
|
||||
<ChevronDownIcon
|
||||
className={cn(
|
||||
"size-4 transition-transform duration-200",
|
||||
"group-data-[state=open]:rotate-180"
|
||||
)}
|
||||
/>
|
||||
</button>
|
||||
</CollapsibleTrigger>
|
||||
</div>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{isExpanded && !isRunning && (
|
||||
<>
|
||||
<div className="mx-5 h-px bg-border/50" />
|
||||
<div className="px-5 py-3 space-y-3">
|
||||
{argsText && (
|
||||
<div>
|
||||
<p className="text-xs font-medium text-muted-foreground mb-1">Arguments</p>
|
||||
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all">
|
||||
{argsText}
|
||||
</pre>
|
||||
{/*
|
||||
CollapsibleContent body — auto-open while streaming
|
||||
(see ``open`` prop above) so the live ``argsText``
|
||||
streams into the Inputs panel directly, no need for
|
||||
a separate "Live input" panel. Native
|
||||
``overflow-auto`` instead of ``ScrollArea`` because
|
||||
Radix's Viewport can let content bleed past
|
||||
``max-h-*`` in dynamic flex layouts. ``min-w-0`` on
|
||||
the column wrappers guarantees ``break-all`` wraps
|
||||
correctly within the bounded ``max-w-lg`` Card.
|
||||
*/}
|
||||
<CollapsibleContent>
|
||||
<Separator />
|
||||
<div className="flex flex-col gap-3 px-5 py-3">
|
||||
{(argsText || isRunning) && (
|
||||
<div className="flex flex-col gap-1 min-w-0">
|
||||
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
|
||||
<div className="max-h-48 overflow-auto rounded-md bg-muted/40">
|
||||
{argsText ? (
|
||||
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||
{argsText}
|
||||
</pre>
|
||||
) : (
|
||||
// Bridges the brief gap between
|
||||
// ``tool-input-start`` (creates the
|
||||
// card, ``argsText`` undefined) and
|
||||
// the first ``tool-input-delta``.
|
||||
<p className="px-3 py-2 text-xs italic text-muted-foreground">
|
||||
Waiting for input…
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{!isCancelled && result !== undefined && (
|
||||
<>
|
||||
<div className="h-px bg-border/30" />
|
||||
<div>
|
||||
<p className="text-xs font-medium text-muted-foreground mb-1">Result</p>
|
||||
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all">
|
||||
{typeof result === "string" ? result : serializedResult}
|
||||
</pre>
|
||||
<Separator />
|
||||
<div className="flex flex-col gap-1 min-w-0">
|
||||
<p className="text-xs font-medium text-muted-foreground">Result</p>
|
||||
<div className="max-h-64 overflow-auto rounded-md bg-muted/40">
|
||||
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||
{typeof result === "string" ? result : serializedResult}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</CollapsibleContent>
|
||||
</Collapsible>
|
||||
</Card>
|
||||
);
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { CheckIcon, CopyIcon, FileText, Pencil } from "lucide-react";
|
||||
import { CheckIcon, CopyIcon, Pencil } from "lucide-react";
|
||||
import Image from "next/image";
|
||||
import { type FC, useState } from "react";
|
||||
import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
|
||||
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
|
||||
interface AuthorMetadata {
|
||||
displayName: string | null;
|
||||
|
|
@ -48,6 +49,19 @@ const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => {
|
|||
|
||||
export const UserMessage: FC = () => {
|
||||
const messageId = useAuiState(({ message }) => message?.id);
|
||||
const messageText = useAuiState(({ message }) =>
|
||||
(message?.content ?? [])
|
||||
.map((part) =>
|
||||
typeof part === "object" &&
|
||||
part !== null &&
|
||||
"type" in part &&
|
||||
(part as { type?: string }).type === "text" &&
|
||||
"text" in part
|
||||
? String((part as { text?: string }).text ?? "")
|
||||
: ""
|
||||
)
|
||||
.join("")
|
||||
);
|
||||
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
|
||||
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
|
||||
const metadata = useAuiState(({ message }) => message?.metadata);
|
||||
|
|
@ -63,22 +77,12 @@ export const UserMessage: FC = () => {
|
|||
<div className="col-start-2 min-w-0">
|
||||
<div className="aui-user-message-content-wrapper flex items-end gap-2">
|
||||
<div className="relative flex-1 min-w-0">
|
||||
{mentionedDocs && mentionedDocs.length > 0 && (
|
||||
<div className="flex flex-wrap items-end gap-2 mb-2 justify-end">
|
||||
{mentionedDocs?.map((doc) => (
|
||||
<span
|
||||
key={`${doc.document_type}:${doc.id}`}
|
||||
className="inline-flex items-center gap-1 px-2 py-0.5 rounded-full bg-primary/10 text-xs font-medium text-primary border border-primary/20"
|
||||
title={doc.title}
|
||||
>
|
||||
<FileText className="size-3" />
|
||||
<span className="max-w-[150px] truncate">{doc.title}</span>
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
|
||||
<MessagePrimitive.Parts />
|
||||
{mentionedDocs && mentionedDocs.length > 0 ? (
|
||||
<UserMessageWithMentionChips text={messageText} mentionedDocs={mentionedDocs} />
|
||||
) : (
|
||||
<MessagePrimitive.Parts />
|
||||
)}
|
||||
</div>
|
||||
<div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto">
|
||||
<UserActionBar />
|
||||
|
|
@ -95,6 +99,64 @@ export const UserMessage: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
/**
 * Renders a user message's plain text with each literal `@<title>` occurrence
 * of a mentioned document replaced by an inline chip (connector icon + title).
 *
 * Matching is a left-to-right literal scan. Tokens are tried longest-first so
 * that a title which is a prefix of another title cannot shadow the longer
 * one. Everything that is not part of a matched token is emitted verbatim as
 * text.
 *
 * @param text - The concatenated text content of the user message.
 * @param mentionedDocs - Documents mentioned in this message; each contributes
 *   the token `@${title}`.
 */
const UserMessageWithMentionChips: FC<{
	text: string;
	mentionedDocs: { id: number; title: string; document_type: string }[];
}> = ({ text, mentionedDocs }) => {
	type MentionDoc = { id: number; title: string; document_type: string };
	type Segment =
		| { type: "text"; value: string; start: number }
		| { type: "mention"; doc: MentionDoc; start: number };

	const tokens = mentionedDocs
		// A document without a title would produce the bare token "@", which
		// matches every "@" in the message (e.g. inside e-mail addresses) and
		// swallows it into an empty chip. Such documents cannot be located in
		// the text, so they are skipped.
		.filter((doc) => Boolean(doc.title))
		.map((doc) => ({ doc, token: `@${doc.title}` }))
		.sort((a, b) => b.token.length - a.token.length);

	const segments: Segment[] = [];
	let cursor = 0;
	// Start offset of the plain-text run that has not been emitted yet.
	let textStart = 0;

	const flushText = (end: number) => {
		if (end > textStart) {
			segments.push({ type: "text", value: text.slice(textStart, end), start: textStart });
		}
	};

	while (cursor < text.length) {
		// Every token starts with "@", so only probe the token list there.
		const match =
			text[cursor] === "@"
				? tokens.find(({ token }) => text.startsWith(token, cursor))
				: undefined;
		if (!match) {
			cursor += 1;
			continue;
		}
		flushText(cursor);
		segments.push({ type: "mention", doc: match.doc, start: cursor });
		cursor += match.token.length;
		textStart = cursor;
	}
	flushText(text.length);

	return (
		<span className="whitespace-pre-wrap break-words">
			{segments.map((segment) =>
				segment.type === "text" ? (
					// The start offset is unique per segment, so it is a stable key.
					<span key={`txt-${segment.start}`}>{segment.value}</span>
				) : (
					<span
						key={`mention-${segment.doc.document_type}:${segment.doc.id}-${segment.start}`}
						className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-baseline"
						title={segment.doc.title}
					>
						<span className="flex items-center text-muted-foreground">
							{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
						</span>
						<span className="max-w-[120px] truncate">{segment.doc.title}</span>
					</span>
				)
			)}
		</span>
	);
};
|
||||
|
||||
const UserActionBar: FC = () => {
|
||||
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import { DndProvider } from "react-dnd";
|
|||
import { HTML5Backend } from "react-dnd-html5-backend";
|
||||
import { renamingFolderIdAtom } from "@/atoms/documents/folder.atoms";
|
||||
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
|
||||
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
|
||||
import { DocumentNode, type DocumentNodeDoc } from "./DocumentNode";
|
||||
import { type FolderDisplay, FolderNode } from "./FolderNode";
|
||||
|
||||
|
|
@ -17,7 +18,7 @@ interface FolderTreeViewProps {
|
|||
documents: DocumentNodeDoc[];
|
||||
expandedIds: Set<number>;
|
||||
onToggleExpand: (folderId: number) => void;
|
||||
mentionedDocIds: Set<number>;
|
||||
mentionedDocKeys: Set<string>;
|
||||
onToggleChatMention: (
|
||||
doc: { id: number; title: string; document_type: string },
|
||||
isMentioned: boolean
|
||||
|
|
@ -62,7 +63,7 @@ export function FolderTreeView({
|
|||
documents,
|
||||
expandedIds,
|
||||
onToggleExpand,
|
||||
mentionedDocIds,
|
||||
mentionedDocKeys,
|
||||
onToggleChatMention,
|
||||
onToggleFolderSelect,
|
||||
onRenameFolder,
|
||||
|
|
@ -181,7 +182,7 @@ export function FolderTreeView({
|
|||
|
||||
function compute(folderId: number): { selected: number; total: number } {
|
||||
const directDocs = (docsByFolder[folderId] ?? []).filter(isSelectable);
|
||||
let selected = directDocs.filter((d) => mentionedDocIds.has(d.id)).length;
|
||||
let selected = directDocs.filter((d) => mentionedDocKeys.has(getMentionDocKey(d))).length;
|
||||
let total = directDocs.length;
|
||||
|
||||
for (const child of foldersByParent[folderId] ?? []) {
|
||||
|
|
@ -202,7 +203,7 @@ export function FolderTreeView({
|
|||
if (states[f.id] === undefined) compute(f.id);
|
||||
}
|
||||
return states;
|
||||
}, [folders, docsByFolder, foldersByParent, mentionedDocIds]);
|
||||
}, [folders, docsByFolder, foldersByParent, mentionedDocKeys]);
|
||||
|
||||
const folderMap = useMemo(() => {
|
||||
const map: Record<number, FolderDisplay> = {};
|
||||
|
|
@ -276,7 +277,7 @@ export function FolderTreeView({
|
|||
key={`doc-${d.id}`}
|
||||
doc={d}
|
||||
depth={depth}
|
||||
isMentioned={mentionedDocIds.has(d.id)}
|
||||
isMentioned={mentionedDocKeys.has(getMentionDocKey(d))}
|
||||
onToggleChatMention={onToggleChatMention}
|
||||
onPreview={onPreviewDocument}
|
||||
onEdit={onEditDocument}
|
||||
|
|
@ -356,7 +357,7 @@ export function FolderTreeView({
|
|||
key={`doc-${d.id}`}
|
||||
doc={d}
|
||||
depth={depth}
|
||||
isMentioned={mentionedDocIds.has(d.id)}
|
||||
isMentioned={mentionedDocKeys.has(getMentionDocKey(d))}
|
||||
onToggleChatMention={onToggleChatMention}
|
||||
onPreview={onPreviewDocument}
|
||||
onEdit={onEditDocument}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import {
|
|||
import { Turnstile, type TurnstileInstance } from "@marsidev/react-turnstile";
|
||||
import { ShieldCheck } from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator";
|
||||
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
|
||||
import {
|
||||
createTokenUsageStore,
|
||||
|
|
@ -17,10 +18,14 @@ import {
|
|||
} from "@/components/assistant-ui/token-usage-context";
|
||||
import { useAnonymousMode } from "@/contexts/anonymous-mode";
|
||||
import {
|
||||
addStepSeparator,
|
||||
addToolCall,
|
||||
appendReasoning,
|
||||
appendText,
|
||||
appendToolInputDelta,
|
||||
buildContentForUI,
|
||||
type ContentPartsState,
|
||||
endReasoning,
|
||||
FrameBatchedUpdater,
|
||||
readSSEStream,
|
||||
type ThinkingStepData,
|
||||
|
|
@ -32,7 +37,9 @@ import { trackAnonymousChatMessageSent } from "@/lib/posthog/events";
|
|||
import { FreeModelSelector } from "./free-model-selector";
|
||||
import { FreeThread } from "./free-thread";
|
||||
|
||||
const TOOLS_WITH_UI = new Set(["web_search", "document_qna"]);
|
||||
// Render all tool calls via ToolFallback; backend keeps persisted
|
||||
// payloads bounded by summarising / truncating outputs.
|
||||
const TOOLS_WITH_UI = "all" as const;
|
||||
const TURNSTILE_SITE_KEY = process.env.NEXT_PUBLIC_TURNSTILE_SITE_KEY ?? "";
|
||||
|
||||
/** Try to parse a CAPTCHA_REQUIRED or CAPTCHA_INVALID code from a non-ok response. */
|
||||
|
|
@ -125,6 +132,7 @@ export function FreeChatPage() {
|
|||
const contentPartsState: ContentPartsState = {
|
||||
contentParts: [],
|
||||
currentTextPartIndex: -1,
|
||||
currentReasoningPartIndex: -1,
|
||||
toolCallIndices: new Map(),
|
||||
};
|
||||
const { toolCallIndices } = contentPartsState;
|
||||
|
|
@ -139,6 +147,10 @@ export function FreeChatPage() {
|
|||
);
|
||||
};
|
||||
const scheduleFlush = () => batcher.schedule(flushMessages);
|
||||
const forceFlush = () => {
|
||||
scheduleFlush();
|
||||
batcher.flush();
|
||||
};
|
||||
|
||||
try {
|
||||
for await (const parsed of readSSEStream(response)) {
|
||||
|
|
@ -148,29 +160,74 @@ export function FreeChatPage() {
|
|||
scheduleFlush();
|
||||
break;
|
||||
|
||||
case "tool-input-start":
|
||||
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
|
||||
batcher.flush();
|
||||
case "reasoning-delta":
|
||||
appendReasoning(contentPartsState, parsed.delta);
|
||||
scheduleFlush();
|
||||
break;
|
||||
|
||||
case "tool-input-available":
|
||||
case "reasoning-end":
|
||||
endReasoning(contentPartsState);
|
||||
scheduleFlush();
|
||||
break;
|
||||
|
||||
case "start-step":
|
||||
addStepSeparator(contentPartsState);
|
||||
scheduleFlush();
|
||||
break;
|
||||
|
||||
case "finish-step":
|
||||
break;
|
||||
|
||||
case "tool-input-start":
|
||||
addToolCall(
|
||||
contentPartsState,
|
||||
TOOLS_WITH_UI,
|
||||
parsed.toolCallId,
|
||||
parsed.toolName,
|
||||
{},
|
||||
false,
|
||||
parsed.langchainToolCallId
|
||||
);
|
||||
forceFlush();
|
||||
break;
|
||||
|
||||
case "tool-input-delta":
|
||||
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
|
||||
scheduleFlush();
|
||||
break;
|
||||
|
||||
case "tool-input-available": {
|
||||
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
|
||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||
updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} });
|
||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||
args: parsed.input || {},
|
||||
argsText: finalArgsText,
|
||||
langchainToolCallId: parsed.langchainToolCallId,
|
||||
});
|
||||
} else {
|
||||
addToolCall(
|
||||
contentPartsState,
|
||||
TOOLS_WITH_UI,
|
||||
parsed.toolCallId,
|
||||
parsed.toolName,
|
||||
parsed.input || {}
|
||||
parsed.input || {},
|
||||
false,
|
||||
parsed.langchainToolCallId
|
||||
);
|
||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||
argsText: finalArgsText,
|
||||
});
|
||||
}
|
||||
batcher.flush();
|
||||
forceFlush();
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool-output-available":
|
||||
updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output });
|
||||
batcher.flush();
|
||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||
result: parsed.output,
|
||||
langchainToolCallId: parsed.langchainToolCallId,
|
||||
});
|
||||
forceFlush();
|
||||
break;
|
||||
|
||||
case "data-thinking-step": {
|
||||
|
|
@ -369,6 +426,7 @@ export function FreeChatPage() {
|
|||
<TokenUsageProvider store={tokenUsageStore}>
|
||||
<AssistantRuntimeProvider runtime={runtime}>
|
||||
<ThinkingStepsDataUI />
|
||||
<StepSeparatorDataUI />
|
||||
<div className="flex h-full flex-col overflow-hidden">
|
||||
<div className="flex h-14 shrink-0 items-center justify-between border-b border-border/40 px-4">
|
||||
<FreeModelSelector />
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ import { useTranslations } from "next-intl";
|
|||
import type React from "react";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { sidebarSelectedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom";
|
||||
import {
|
||||
mentionedDocumentsAtom,
|
||||
} from "@/atoms/chat/mentioned-documents.atom";
|
||||
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
|
||||
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
|
||||
import { deleteDocumentMutationAtom } from "@/atoms/documents/document-mutation.atoms";
|
||||
|
|
@ -72,6 +74,7 @@ import type { DocumentTypeEnum } from "@/contracts/types/document.types";
|
|||
import { useDebouncedValue } from "@/hooks/use-debounced-value";
|
||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { useElectronAPI, usePlatform } from "@/hooks/use-platform";
|
||||
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
|
||||
import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service";
|
||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||
import { foldersApiService } from "@/lib/apis/folders-api.service";
|
||||
|
|
@ -425,8 +428,11 @@ function AuthenticatedDocumentsSidebarBase({
|
|||
}, [refreshWatchedIds]);
|
||||
const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom);
|
||||
|
||||
const [sidebarDocs, setSidebarDocs] = useAtom(sidebarSelectedDocumentsAtom);
|
||||
const mentionedDocIds = useMemo(() => new Set(sidebarDocs.map((d) => d.id)), [sidebarDocs]);
|
||||
const [sidebarDocs, setSidebarDocs] = useAtom(mentionedDocumentsAtom);
|
||||
const mentionedDocKeys = useMemo(
|
||||
() => new Set(sidebarDocs.map((d) => getMentionDocKey(d))),
|
||||
[sidebarDocs]
|
||||
);
|
||||
|
||||
// Folder state
|
||||
const [expandedFolderMap, setExpandedFolderMap] = useAtom(expandedFolderIdsAtom);
|
||||
|
|
@ -874,11 +880,12 @@ function AuthenticatedDocumentsSidebarBase({
|
|||
|
||||
const handleToggleChatMention = useCallback(
|
||||
(doc: { id: number; title: string; document_type: string }, isMentioned: boolean) => {
|
||||
const key = getMentionDocKey(doc);
|
||||
if (isMentioned) {
|
||||
setSidebarDocs((prev) => prev.filter((d) => d.id !== doc.id));
|
||||
setSidebarDocs((prev) => prev.filter((d) => getMentionDocKey(d) !== key));
|
||||
} else {
|
||||
setSidebarDocs((prev) => {
|
||||
if (prev.some((d) => d.id === doc.id)) return prev;
|
||||
if (prev.some((d) => getMentionDocKey(d) === key)) return prev;
|
||||
return [
|
||||
...prev,
|
||||
{ id: doc.id, title: doc.title, document_type: doc.document_type as DocumentTypeEnum },
|
||||
|
|
@ -909,9 +916,9 @@ function AuthenticatedDocumentsSidebarBase({
|
|||
|
||||
if (selectAll) {
|
||||
setSidebarDocs((prev) => {
|
||||
const existingIds = new Set(prev.map((d) => d.id));
|
||||
const existingDocKeys = new Set(prev.map((d) => getMentionDocKey(d)));
|
||||
const newDocs = subtreeDocs
|
||||
.filter((d) => !existingIds.has(d.id))
|
||||
.filter((d) => !existingDocKeys.has(getMentionDocKey(d)))
|
||||
.map((d) => ({
|
||||
id: d.id,
|
||||
title: d.title,
|
||||
|
|
@ -920,8 +927,8 @@ function AuthenticatedDocumentsSidebarBase({
|
|||
return newDocs.length > 0 ? [...prev, ...newDocs] : prev;
|
||||
});
|
||||
} else {
|
||||
const idsToRemove = new Set(subtreeDocs.map((d) => d.id));
|
||||
setSidebarDocs((prev) => prev.filter((d) => !idsToRemove.has(d.id)));
|
||||
const keysToRemove = new Set(subtreeDocs.map((d) => getMentionDocKey(d)));
|
||||
setSidebarDocs((prev) => prev.filter((d) => !keysToRemove.has(getMentionDocKey(d))));
|
||||
}
|
||||
},
|
||||
[treeDocuments, foldersByParent, setSidebarDocs]
|
||||
|
|
@ -1157,7 +1164,7 @@ function AuthenticatedDocumentsSidebarBase({
|
|||
documents={searchFilteredDocuments}
|
||||
expandedIds={expandedIds}
|
||||
onToggleExpand={toggleFolderExpand}
|
||||
mentionedDocIds={mentionedDocIds}
|
||||
mentionedDocKeys={mentionedDocKeys}
|
||||
onToggleChatMention={handleToggleChatMention}
|
||||
onToggleFolderSelect={handleToggleFolderSelect}
|
||||
onRenameFolder={handleRenameFolder}
|
||||
|
|
@ -1585,16 +1592,20 @@ function AnonymousDocumentsSidebar({
|
|||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [search, setSearch] = useState("");
|
||||
|
||||
const [sidebarDocs, setSidebarDocs] = useAtom(sidebarSelectedDocumentsAtom);
|
||||
const mentionedDocIds = useMemo(() => new Set(sidebarDocs.map((d) => d.id)), [sidebarDocs]);
|
||||
const [sidebarDocs, setSidebarDocs] = useAtom(mentionedDocumentsAtom);
|
||||
const mentionedDocKeys = useMemo(
|
||||
() => new Set(sidebarDocs.map((d) => getMentionDocKey(d))),
|
||||
[sidebarDocs]
|
||||
);
|
||||
|
||||
const handleToggleChatMention = useCallback(
|
||||
(doc: { id: number; title: string; document_type: string }, isMentioned: boolean) => {
|
||||
const key = getMentionDocKey(doc);
|
||||
if (isMentioned) {
|
||||
setSidebarDocs((prev) => prev.filter((d) => d.id !== doc.id));
|
||||
setSidebarDocs((prev) => prev.filter((d) => getMentionDocKey(d) !== key));
|
||||
} else {
|
||||
setSidebarDocs((prev) => {
|
||||
if (prev.some((d) => d.id === doc.id)) return prev;
|
||||
if (prev.some((d) => getMentionDocKey(d) === key)) return prev;
|
||||
return [
|
||||
...prev,
|
||||
{ id: doc.id, title: doc.title, document_type: doc.document_type as DocumentTypeEnum },
|
||||
|
|
@ -1814,7 +1825,7 @@ function AnonymousDocumentsSidebar({
|
|||
documents={searchFilteredDocs}
|
||||
expandedIds={new Set()}
|
||||
onToggleExpand={() => {}}
|
||||
mentionedDocIds={mentionedDocIds}
|
||||
mentionedDocKeys={mentionedDocKeys}
|
||||
onToggleChatMention={handleToggleChatMention}
|
||||
onToggleFolderSelect={() => {}}
|
||||
onRenameFolder={() => gate("rename folders")}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { AssistantRuntimeProvider } from "@assistant-ui/react";
|
||||
import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator";
|
||||
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
|
||||
import { Navbar } from "@/components/homepage/navbar";
|
||||
import { ReportPanel } from "@/components/report-panel/report-panel";
|
||||
|
|
@ -41,6 +42,7 @@ export function PublicChatView({ shareToken }: PublicChatViewProps) {
|
|||
<Navbar scrolledBgClassName={navbarScrolledBg} />
|
||||
<AssistantRuntimeProvider runtime={runtime}>
|
||||
<ThinkingStepsDataUI />
|
||||
<StepSeparatorDataUI />
|
||||
<div className="flex h-screen pt-16 overflow-hidden">
|
||||
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
|
||||
<PublicThread footer={<PublicChatFooter shareToken={shareToken} />} />
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import Image from "next/image";
|
|||
import { type FC, type ReactNode, useState } from "react";
|
||||
import { CitationMetadataProvider } from "@/components/assistant-ui/citation-metadata-context";
|
||||
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
|
||||
import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part";
|
||||
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
import { GenerateImageToolUI } from "@/components/tool-ui/generate-image";
|
||||
|
|
@ -157,6 +158,7 @@ const PublicAssistantMessage: FC = () => {
|
|||
<MessagePrimitive.Parts
|
||||
components={{
|
||||
Text: MarkdownText,
|
||||
Reasoning: ReasoningMessagePart,
|
||||
tools: {
|
||||
by_name: {
|
||||
generate_podcast: GeneratePodcastToolUI,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
|||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
import { connectorsApiService } from "@/lib/apis/connectors-api.service";
|
||||
import type { HitlDecision, InterruptResult } from "@/lib/hitl";
|
||||
|
|
@ -77,7 +78,7 @@ function GenericApprovalCard({
|
|||
const [editedParams, setEditedParams] = useState<Record<string, unknown>>(args);
|
||||
const [isEditing, setIsEditing] = useState(false);
|
||||
|
||||
const displayName = toolName.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
|
||||
const displayName = getToolDisplayName(toolName);
|
||||
|
||||
const mcpServer = interruptData.context?.mcp_server as string | undefined;
|
||||
const toolDescription = interruptData.context?.tool_description as string | undefined;
|
||||
|
|
@ -186,12 +187,11 @@ function GenericApprovalCard({
|
|||
</>
|
||||
)}
|
||||
|
||||
{/* Parameters */}
|
||||
{Object.keys(args).length > 0 && (
|
||||
<>
|
||||
<div className="mx-5 h-px bg-border/50" />
|
||||
<div className="px-5 py-4 space-y-2">
|
||||
<p className="text-xs font-medium text-muted-foreground">Parameters</p>
|
||||
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
|
||||
{phase === "pending" && isEditing ? (
|
||||
<ParamEditor
|
||||
params={editedParams}
|
||||
|
|
|
|||
|
|
@ -1,33 +1,223 @@
|
|||
import {
|
||||
BookOpen,
|
||||
Brain,
|
||||
Calendar,
|
||||
Check,
|
||||
FileEdit,
|
||||
FilePlus,
|
||||
FileText,
|
||||
FileUser,
|
||||
FileX,
|
||||
Film,
|
||||
FolderPlus,
|
||||
FolderTree,
|
||||
FolderX,
|
||||
Globe,
|
||||
ImageIcon,
|
||||
ListTodo,
|
||||
type LucideIcon,
|
||||
Mail,
|
||||
MessagesSquare,
|
||||
Move,
|
||||
Plus,
|
||||
Podcast,
|
||||
ScanLine,
|
||||
Search,
|
||||
Send,
|
||||
Trash2,
|
||||
Wrench,
|
||||
} from "lucide-react";
|
||||
|
||||
/**
 * Every tool now renders a card via ``ToolFallback``. The icon map is
 * keyed on the canonical backend tool name (registered in
 * ``surfsense_backend/app/agents/new_chat/tools/registry.py``); unknown
 * names fall back to the generic ``Wrench`` icon so the card still
 * communicates "this is a tool call".
 */
const TOOL_ICONS: Record<string, LucideIcon> = {
	// Generators
	generate_podcast: Podcast,
	generate_video_presentation: Film,
	generate_report: FileText,
	generate_resume: FileUser,
	generate_image: ImageIcon,
	display_image: ImageIcon,
	// Web / search
	scrape_webpage: ScanLine,
	web_search: Globe,
	search_surfsense_docs: BookOpen,
	// Memory
	update_memory: Brain,
	// Filesystem (built-in deepagent + middleware)
	read_file: FileText,
	write_file: FilePlus,
	edit_file: FileEdit,
	move_file: Move,
	rm: FileX,
	rmdir: FolderX,
	mkdir: FolderPlus,
	ls: FolderTree,
	write_todos: ListTodo,
	// Calendar
	search_calendar_events: Search,
	create_calendar_event: Calendar,
	update_calendar_event: Calendar,
	delete_calendar_event: Calendar,
	// Gmail
	search_gmail: Search,
	read_gmail_email: Mail,
	create_gmail_draft: Mail,
	update_gmail_draft: FileEdit,
	send_gmail_email: Send,
	trash_gmail_email: Trash2,
	// Notion / Confluence pages
	create_notion_page: FilePlus,
	update_notion_page: FileEdit,
	delete_notion_page: FileX,
	create_confluence_page: FilePlus,
	update_confluence_page: FileEdit,
	delete_confluence_page: FileX,
	// Linear / Jira issues
	create_linear_issue: Plus,
	update_linear_issue: FileEdit,
	delete_linear_issue: Trash2,
	create_jira_issue: Plus,
	update_jira_issue: FileEdit,
	delete_jira_issue: Trash2,
	// Drive-like file connectors
	create_google_drive_file: FilePlus,
	delete_google_drive_file: FileX,
	create_dropbox_file: FilePlus,
	delete_dropbox_file: FileX,
	create_onedrive_file: FilePlus,
	delete_onedrive_file: FileX,
	// Chat connectors
	list_discord_channels: MessagesSquare,
	read_discord_messages: MessagesSquare,
	send_discord_message: Send,
	list_teams_channels: MessagesSquare,
	read_teams_messages: MessagesSquare,
	send_teams_message: Send,
	// Luma
	list_luma_events: Calendar,
	read_luma_event: Calendar,
	create_luma_event: Calendar,
	// Misc
	get_connected_accounts: Check,
	execute: Wrench,
	execute_code: Wrench,
};

/**
 * Resolve the icon component for a tool.
 *
 * Only the map's own keys are consulted. A plain index lookup would also
 * resolve members inherited from ``Object.prototype``, so a tool that
 * happens to be named ``constructor`` or ``toString`` would return a
 * function that is not an icon component.
 *
 * @param name - Canonical (snake_case) tool name.
 * @returns The mapped icon, or ``Wrench`` when the tool has no entry.
 */
export function getToolIcon(name: string): LucideIcon {
	const icon = Object.prototype.hasOwnProperty.call(TOOL_ICONS, name)
		? TOOL_ICONS[name]
		: undefined;
	return icon ?? Wrench;
}
|
||||
|
||||
/**
 * Friendly display names for tools shown in the chat UI.
 *
 * Most users aren't engineers; they shouldn't see raw unix-style
 * identifiers like ``rm`` / ``rmdir`` / ``ls`` / ``grep`` / ``glob`` or
 * snake_cased function names. The map below renders each tool with
 * plain English wording (verb + object) so non-technical users
 * understand what the agent is doing at a glance.
 *
 * Unmapped tool names fall back to a snake_case-to-Title-Case
 * conversion via ``getToolDisplayName``.
 */
const TOOL_DISPLAY_NAMES: Record<string, string> = {
	// Filesystem / knowledge base
	read_file: "Read file",
	write_file: "Write file",
	edit_file: "Edit file",
	move_file: "Move file",
	rm: "Delete file",
	rmdir: "Delete folder",
	mkdir: "Create folder",
	ls: "List files",
	glob: "Find files",
	grep: "Search in files",
	write_todos: "Plan tasks",
	save_document: "Save document",
	// Generators
	generate_podcast: "Generate podcast",
	generate_video_presentation: "Generate video presentation",
	generate_report: "Generate report",
	generate_resume: "Generate resume",
	generate_image: "Generate image",
	display_image: "Show image",
	// Web / search
	scrape_webpage: "Read webpage",
	web_search: "Search the web",
	search_surfsense_docs: "Search knowledge base",
	// Memory
	update_memory: "Update memory",
	// Calendar
	search_calendar_events: "Search calendar",
	create_calendar_event: "Create event",
	update_calendar_event: "Update event",
	delete_calendar_event: "Delete event",
	// Gmail
	search_gmail: "Search Gmail",
	read_gmail_email: "Read email",
	create_gmail_draft: "Draft email",
	update_gmail_draft: "Update draft",
	send_gmail_email: "Send email",
	trash_gmail_email: "Move email to trash",
	// Notion
	create_notion_page: "Create Notion page",
	update_notion_page: "Update Notion page",
	delete_notion_page: "Delete Notion page",
	// Confluence
	create_confluence_page: "Create Confluence page",
	update_confluence_page: "Update Confluence page",
	delete_confluence_page: "Delete Confluence page",
	// Linear
	create_linear_issue: "Create Linear issue",
	update_linear_issue: "Update Linear issue",
	delete_linear_issue: "Delete Linear issue",
	// Jira
	create_jira_issue: "Create Jira issue",
	update_jira_issue: "Update Jira issue",
	delete_jira_issue: "Delete Jira issue",
	// Drive-like file connectors
	create_google_drive_file: "Create Google Drive file",
	delete_google_drive_file: "Delete Google Drive file",
	create_dropbox_file: "Create Dropbox file",
	delete_dropbox_file: "Delete Dropbox file",
	create_onedrive_file: "Create OneDrive file",
	delete_onedrive_file: "Delete OneDrive file",
	// Discord
	list_discord_channels: "List Discord channels",
	read_discord_messages: "Read Discord messages",
	send_discord_message: "Send Discord message",
	// Teams
	list_teams_channels: "List Teams channels",
	read_teams_messages: "Read Teams messages",
	send_teams_message: "Send Teams message",
	// Luma
	list_luma_events: "List Luma events",
	read_luma_event: "Read Luma event",
	create_luma_event: "Create Luma event",
	// Misc
	get_connected_accounts: "Check connected accounts",
	execute: "Run command",
	execute_code: "Run code",
};

/**
 * Format a tool's canonical (snake_case) name for display in the chat UI.
 *
 * Looks up ``TOOL_DISPLAY_NAMES`` first; falls back to a
 * snake_case-to-Title-Case rewrite for tools that don't have a curated
 * label (e.g. dynamically registered MCP tools).
 *
 * Only the map's own keys count as curated labels. A plain index lookup
 * would also resolve members inherited from ``Object.prototype``, so a
 * tool named ``constructor`` or ``toString`` would return a function
 * instead of a string.
 *
 * @param name - Canonical tool name as emitted by the backend.
 * @returns A human-readable label; always a string.
 */
export function getToolDisplayName(name: string): string {
	const friendly = Object.prototype.hasOwnProperty.call(TOOL_DISPLAY_NAMES, name)
		? TOOL_DISPLAY_NAMES[name]
		: undefined;
	if (friendly) return friendly;
	// Underscores become spaces, then the first letter of each word is upper-cased.
	return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
}
|
||||
|
||||
export const CONNECTOR_TOOL_ICON_PATHS: Record<string, { src: string; alt: string }> = {
|
||||
gmail: { src: "/connectors/google-gmail.svg", alt: "Gmail" },
|
||||
google_calendar: { src: "/connectors/google-calendar.svg", alt: "Google Calendar" },
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
import { z } from "zod";
|
||||
|
||||
/**
|
||||
* Raw message from database (real-time sync)
|
||||
* Raw message from database (real-time sync).
|
||||
*
|
||||
* ``turn_id`` is included so consumers (e.g. ``convertToThreadMessage``)
|
||||
* can populate ``metadata.custom.chatTurnId`` on the
|
||||
* ``ThreadMessageLike`` even after the live-collab Zero re-sync. The
|
||||
* inline Revert button's ``(chat_turn_id, tool_name, position)``
|
||||
* fallback in tool-fallback.tsx depends on it.
|
||||
*/
|
||||
export const rawMessage = z.object({
|
||||
id: z.number(),
|
||||
|
|
@ -10,6 +16,7 @@ export const rawMessage = z.object({
|
|||
content: z.unknown(),
|
||||
author_id: z.string().nullable(),
|
||||
created_at: z.string(),
|
||||
turn_id: z.string().nullable().optional(),
|
||||
});
|
||||
|
||||
export type RawMessage = z.infer<typeof rawMessage>;
|
||||
|
|
|
|||
416
surfsense_web/hooks/use-agent-actions-query.ts
Normal file
416
surfsense_web/hooks/use-agent-actions-query.ts
Normal file
|
|
@ -0,0 +1,416 @@
|
|||
"use client";
|
||||
|
||||
import { type QueryClient, useQuery } from "@tanstack/react-query";
|
||||
import { useCallback, useEffect, useMemo, useRef } from "react";
|
||||
import {
|
||||
type AgentAction,
|
||||
type AgentActionListResponse,
|
||||
agentActionsApiService,
|
||||
} from "@/lib/apis/agent-actions-api.service";
|
||||
|
||||
// -----------------------------------------------------------------------------
// Diagnostic tracing for the SSE → cache → card → button pipeline.
//
// Everything is gated behind the single ``RevertDebug`` switch below. It is
// off by default so production consoles stay quiet; set it to ``true`` to
// trace the pipeline in the browser console. The plumbing is kept around
// because the id-mismatch failure it helps diagnose is rare and only shows
// up at runtime.
// -----------------------------------------------------------------------------
const RevertDebug = false;

/** Forward `args` to `console.log` when `RevertDebug` is on and a `window` exists. */
function dbg(...args: unknown[]): void {
  if (!RevertDebug) return;
  if (typeof window === "undefined") return;
  // eslint-disable-next-line no-console
  console.log("[RevertDebug]", ...args);
}
|
||||
|
||||
/**
|
||||
* Unified store for ``AgentActionLog`` rows scoped to one thread.
|
||||
*
|
||||
* Replaces the previous SSE side-channel atom mess
|
||||
* (``agentActionByLcIdAtom`` / ``agentActionByToolCallIdAtom`` /
|
||||
* ``agentActionsByChatTurnIdAtom``) and the standalone hydration hook.
|
||||
* One react-query cache entry is now the single source of truth for:
|
||||
*
|
||||
* * the inline Revert button on every tool-call card
|
||||
* * the per-turn "Revert turn" button under each assistant message
|
||||
* * the edit-from-position pre-flight that decides whether to show
|
||||
* the confirmation dialog
|
||||
* * the agent-actions sheet
|
||||
*
|
||||
* The cache is hydrated by ``GET /threads/{id}/actions`` (sized to
|
||||
* 200, the server max) and updated incrementally by helpers that turn
|
||||
* SSE events / revert RPC responses into ``setQueryData`` mutations.
|
||||
* That keeps the card and the sheet in lockstep on every code path —
|
||||
* page reload, navigation, live stream, post-stream reversibility flip,
|
||||
* and explicit revert clicks.
|
||||
*/
|
||||
|
||||
export const ACTION_LOG_PAGE_SIZE = 200;
|
||||
|
||||
/**
 * Build the react-query key under which one thread's action list is cached.
 *
 * @param threadId - Thread id, or `null` when there is no thread yet.
 * @returns `["agent-actions", threadId]`, or `["agent-actions", "none"]`
 *   for a `null` id.
 */
export function agentActionsQueryKey(threadId: number | null) {
  if (threadId === null) {
    return ["agent-actions", "none"] as const;
  }
  return ["agent-actions", threadId] as const;
}
|
||||
|
||||
/** Subset of the SSE ``data-action-log`` payload we care about. */
|
||||
export interface ActionLogSseEvent {
|
||||
id: number;
|
||||
lc_tool_call_id: string | null;
|
||||
chat_turn_id: string | null;
|
||||
tool_name: string;
|
||||
reversible: boolean;
|
||||
reverse_descriptor_present: boolean;
|
||||
error: boolean;
|
||||
created_at: string | null;
|
||||
}
|
||||
|
||||
/**
 * Upsert one freshly emitted ``AgentActionLog`` row into the thread's
 * cached action list.
 *
 * The SSE payload carries only part of an ``AgentAction``. Fields it lacks
 * (``args``, ``result_id``, ``user_id``, ...) are stored as ``null``;
 * ``reverse_descriptor`` and ``error`` become ``{}`` when the event flags
 * them as present and ``null`` otherwise. Per the original design note, a
 * later refetch of the list backfills the real values.
 *
 * Three outcomes:
 *   - no cache entry yet: seed a one-row page;
 *   - a row with the same ``id`` exists: patch ``reversible`` and fill in
 *     ``tool_call_id`` / ``chat_turn_id`` when the event carries them;
 *   - otherwise: prepend the row and bump ``total``.
 *
 * @param queryClient - react-query client that owns the cache.
 * @param threadId - Thread whose action list is updated.
 * @param searchSpaceId - Stored on the placeholder row as ``search_space_id``.
 * @param event - The ``data-action-log`` payload.
 */
export function applyActionLogSse(
  queryClient: QueryClient,
  threadId: number,
  searchSpaceId: number,
  event: ActionLogSseEvent
): void {
  dbg("applyActionLogSse: incoming SSE event", {
    threadId,
    searchSpaceId,
    event,
  });

  const upsert = (cached: AgentActionListResponse | undefined): AgentActionListResponse => {
    const row: AgentAction = {
      id: event.id,
      thread_id: threadId,
      user_id: null,
      search_space_id: searchSpaceId,
      tool_name: event.tool_name,
      args: null,
      result_id: null,
      reversible: event.reversible,
      reverse_descriptor: event.reverse_descriptor_present ? {} : null,
      error: event.error ? {} : null,
      reverse_of: null,
      reverted_by_action_id: null,
      is_revert_action: false,
      tool_call_id: event.lc_tool_call_id,
      chat_turn_id: event.chat_turn_id,
      created_at: event.created_at ?? new Date().toISOString(),
    };

    // Nothing cached for this thread yet: start the list with this row.
    if (!cached) {
      return {
        items: [row],
        total: 1,
        page: 0,
        page_size: ACTION_LOG_PAGE_SIZE,
        has_more: false,
      };
    }

    const at = cached.items.findIndex((a) => a.id === event.id);
    if (at < 0) {
      dbg("applyActionLogSse: appended new placeholder", {
        id: event.id,
        tool_call_id: row.tool_call_id,
        tool_name: row.tool_name,
        reversible: row.reversible,
        cacheSizeAfter: cached.items.length + 1,
      });
      // The REST list is newest-first, so prepend to match the order a
      // later refetch will return.
      return {
        ...cached,
        items: [row, ...cached.items],
        total: cached.total + 1,
      };
    }

    // Known row: keep everything already cached and only patch the
    // fields this event can speak for.
    const next = [...cached.items];
    const current = next[at];
    if (current) {
      next[at] = {
        ...current,
        reversible: event.reversible,
        tool_call_id: event.lc_tool_call_id ?? current.tool_call_id,
        chat_turn_id: event.chat_turn_id ?? current.chat_turn_id,
      };
    }
    dbg("applyActionLogSse: merged into existing entry", {
      id: event.id,
      tool_call_id: next[at]?.tool_call_id,
      reversible: next[at]?.reversible,
    });
    return { ...cached, items: next };
  };

  queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), upsert);
}
|
||||
|
||||
/**
 * Apply a ``data-action-log-updated`` SSE event — the post-SAVEPOINT
 * reversibility flip — to the thread's cached action list.
 *
 * The cached value is returned unchanged when the thread has no cache
 * entry yet or when no cached row has the given ``id``; both cases are
 * traced through ``dbg``.
 *
 * @param queryClient - react-query client that owns the cache.
 * @param threadId - Thread whose action list is updated.
 * @param id - Id of the action row to patch.
 * @param reversible - New value for the row's ``reversible`` flag.
 */
export function applyActionLogUpdatedSse(
  queryClient: QueryClient,
  threadId: number,
  id: number,
  reversible: boolean
): void {
  dbg("applyActionLogUpdatedSse: reversibility flip", {
    threadId,
    id,
    reversible,
  });
  queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), (cached) => {
    if (!cached) {
      dbg("applyActionLogUpdatedSse: NO prev cache for thread; flip dropped", {
        threadId,
        id,
      });
      return cached;
    }

    const isKnown = cached.items.some((a) => a.id === id);
    if (!isKnown) {
      dbg("applyActionLogUpdatedSse: id not in cache; flip dropped", {
        threadId,
        id,
        cacheSize: cached.items.length,
        cacheIds: cached.items.map((a) => a.id),
      });
      return cached;
    }

    return {
      ...cached,
      items: cached.items.map((a) => (a.id === id ? { ...a, reversible } : a)),
    };
  });
}
|
||||
|
||||
/**
 * Flag the cached action ``id`` as reverted without waiting for a refetch.
 *
 * Sets the row's ``reverted_by_action_id`` to ``newActionId``, the id of
 * the ``is_revert_action`` row the server inserted. When the caller passes
 * ``null`` (the server did not return that id) the sentinel ``-1`` is
 * stored instead, meaning "reverted, new row id unknown".
 *
 * The cached value is returned unchanged when the thread has no cache
 * entry or no row matches ``id``.
 *
 * @param queryClient - react-query client that owns the cache.
 * @param threadId - Thread whose action list is updated.
 * @param id - Id of the action that was reverted.
 * @param newActionId - Id of the inserted revert row, or ``null``.
 */
export function markActionRevertedInCache(
  queryClient: QueryClient,
  threadId: number,
  id: number,
  newActionId: number | null
): void {
  const revertedBy = newActionId ?? -1;
  queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), (cached) => {
    if (!cached) return cached;
    if (!cached.items.some((a) => a.id === id)) return cached;
    return {
      ...cached,
      items: cached.items.map((a) =>
        a.id === id ? { ...a, reverted_by_action_id: revertedBy } : a
      ),
    };
  });
}
|
||||
|
||||
/**
 * Apply a batch of revert outcomes (as produced by a per-turn revert) to
 * the thread's cached action list.
 *
 * Every cached row whose id appears in ``entries`` gets its
 * ``reverted_by_action_id`` set to the entry's ``newActionId``, or to the
 * sentinel ``-1`` when that is ``null``. Rows not listed are left alone.
 * If an id is listed more than once, the last entry wins.
 *
 * Does nothing for an empty batch; returns the cached value unchanged
 * when there is no cache entry or no listed id is cached.
 *
 * @param queryClient - react-query client that owns the cache.
 * @param threadId - Thread whose action list is updated.
 * @param entries - Reverted action ids with the id of their revert row.
 */
export function applyRevertTurnResultsToCache(
  queryClient: QueryClient,
  threadId: number,
  entries: Array<{ id: number; newActionId: number | null }>
): void {
  if (entries.length === 0) return;
  queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), (cached) => {
    if (!cached) return cached;

    // action id -> value to store in reverted_by_action_id
    const revertedBy = new Map<number, number>();
    for (const { id, newActionId } of entries) {
      revertedBy.set(id, newActionId ?? -1);
    }

    if (!cached.items.some((a) => revertedBy.has(a.id))) return cached;

    const items = cached.items.map((a) => {
      const marker = revertedBy.get(a.id);
      return marker === undefined ? a : { ...a, reverted_by_action_id: marker };
    });
    return { ...cached, items };
  });
}
|
||||
|
||||
/**
 * Read-side hook over the per-thread action list. Per the file's design
 * notes it backs the tool-call card, the per-turn revert button, the
 * actions sheet and the edit-from-position pre-flight.
 *
 * Fetches the first page (``ACTION_LOG_PAGE_SIZE`` rows) through
 * ``agentActionsApiService.listForThread`` and treats the result as fresh
 * for 15 seconds. The query stays dormant while ``threadId`` is ``null``
 * or ``options.enabled`` is ``false``.
 *
 * @param threadId - Thread to read, or ``null`` when none is selected yet.
 * @param options - ``enabled`` (default ``true``) gates the fetch.
 * @returns The raw react-query state plus:
 *   - ``items``: cached rows (empty array until data arrives);
 *   - ``findByToolCallId``: row whose ``tool_call_id`` matches, or ``null``;
 *   - ``findByChatTurnId``: all rows of a turn, in cache order;
 *   - ``findByChatTurnAndTool``: non-revert rows of a turn for one tool,
 *     oldest first.
 */
export function useAgentActionsQuery(
  threadId: number | null,
  options: { enabled?: boolean } = {}
) {
  const enabled = (options.enabled ?? true) && threadId !== null;
  const query = useQuery({
    queryKey: agentActionsQueryKey(threadId),
    queryFn: async () => {
      dbg("useAgentActionsQuery: REST fetch START", {
        threadId,
        pageSize: ACTION_LOG_PAGE_SIZE,
      });
      // ``enabled`` guarantees a non-null id whenever this runs.
      const res = await agentActionsApiService.listForThread(threadId as number, {
        page: 0,
        pageSize: ACTION_LOG_PAGE_SIZE,
      });
      dbg("useAgentActionsQuery: REST fetch DONE", {
        threadId,
        total: res.total,
        returned: res.items.length,
        items: res.items.map((a) => ({
          id: a.id,
          tool_name: a.tool_name,
          tool_call_id: a.tool_call_id,
          reversible: a.reversible,
          reverted_by_action_id: a.reverted_by_action_id,
          is_revert_action: a.is_revert_action,
        })),
      });
      return res;
    },
    enabled,
    staleTime: 15 * 1000,
  });

  const items = useMemo(() => query.data?.items ?? [], [query.data]);

  // Index ``items`` once per change so ``findByToolCallId`` is a Map
  // lookup instead of a scan per card per render.
  const byToolCallId = useMemo(() => {
    const m = new Map<string, AgentAction>();
    for (const a of items) {
      if (a.tool_call_id) m.set(a.tool_call_id, a);
    }
    return m;
  }, [items]);

  // Rows grouped by (chat_turn_id, tool_name) and sorted oldest-first so
  // the positional fallback can index straight into the bucket.
  // ``is_revert_action`` rows are excluded so positions line up with the
  // agent's original calls.
  const byTurnAndTool = useMemo(() => {
    const m = new Map<string, AgentAction[]>();
    for (const a of items) {
      if (!a.chat_turn_id || a.is_revert_action) continue;
      const key = `${a.chat_turn_id}::${a.tool_name}`;
      const bucket = m.get(key);
      if (bucket) bucket.push(a);
      else m.set(key, [a]);
    }
    for (const bucket of m.values()) {
      // Break timestamp ties (and unparsable timestamps, where the
      // difference is NaN) on ``id``. Without this, tied rows keep the
      // cache's newest-first order and the position index is inverted.
      // NOTE(review): assumes ids increase in insertion order — confirm
      // against the backend model.
      bucket.sort(
        (a, b) =>
          new Date(a.created_at).getTime() - new Date(b.created_at).getTime() || a.id - b.id
      );
    }
    return m;
  }, [items]);

  // Trace the cache shape whenever the thread or the row count changes.
  // The last traced shape lives on a ref so re-renders with an unchanged
  // count do not log again.
  const lastSnapshotRef = useRef<{ threadId: number | null; size: number } | null>(null);
  useEffect(() => {
    const last = lastSnapshotRef.current;
    if (!last || last.threadId !== threadId || last.size !== items.length) {
      dbg("useAgentActionsQuery: cache snapshot", {
        threadId,
        enabled,
        itemCount: items.length,
        itemKeys: items.slice(0, 8).map((a) => ({
          id: a.id,
          tool_name: a.tool_name,
          tool_call_id: a.tool_call_id,
          chat_turn_id: a.chat_turn_id,
          reversible: a.reversible,
        })),
      });
      lastSnapshotRef.current = { threadId, size: items.length };
    }
  }, [threadId, enabled, items]);

  const findByToolCallId = useCallback(
    (toolCallId: string | null | undefined): AgentAction | null => {
      if (!toolCallId) return null;
      const found = byToolCallId.get(toolCallId) ?? null;
      // Only a miss against a non-empty cache is worth tracing.
      if (!found && items.length > 0) {
        dbg("findByToolCallId: MISS", {
          queriedToolCallId: toolCallId,
          itemCount: items.length,
          availableToolCallIds: Array.from(byToolCallId.keys()),
        });
      }
      return found;
    },
    [byToolCallId, items.length]
  );

  const findByChatTurnId = useCallback(
    (chatTurnId: string | null | undefined): AgentAction[] => {
      if (!chatTurnId) return [];
      // Deliberately unindexed: a linear scan is considered cheap enough
      // for the per-turn aggregation.
      return items.filter((a) => a.chat_turn_id === chatTurnId);
    },
    [items]
  );

  const findByChatTurnAndTool = useCallback(
    (
      chatTurnId: string | null | undefined,
      toolName: string | null | undefined
    ): AgentAction[] => {
      if (!chatTurnId || !toolName) return [];
      return byTurnAndTool.get(`${chatTurnId}::${toolName}`) ?? [];
    },
    [byTurnAndTool]
  );

  return {
    ...query,
    items,
    findByToolCallId,
    findByChatTurnId,
    findByChatTurnAndTool,
  };
}
|
||||
|
|
@ -31,6 +31,14 @@ export function useMessagesSync(
|
|||
content: msg.content,
|
||||
author_id: msg.authorId ?? null,
|
||||
created_at: new Date(msg.createdAt).toISOString(),
|
||||
// Forward the per-turn correlation id so post-stream Zero
|
||||
// re-syncs preserve ``metadata.custom.chatTurnId`` on the
|
||||
// converted ``ThreadMessageLike``. Without this the inline
|
||||
// Revert button's ``(chat_turn_id, tool_name, position)``
|
||||
// fallback breaks the moment Zero overwrites the messages
|
||||
// state after a live stream completes (see
|
||||
// ``handleSyncedMessagesUpdate`` in the chat page).
|
||||
turn_id: msg.turnId ?? null,
|
||||
}));
|
||||
|
||||
onMessagesUpdateRef.current(mapped);
|
||||
|
|
|
|||
|
|
@ -15,6 +15,12 @@ const AgentActionReadSchema = z.object({
|
|||
reverse_of: z.number().nullable(),
|
||||
reverted_by_action_id: z.number().nullable(),
|
||||
is_revert_action: z.boolean(),
|
||||
// Correlation ids added in migration 135. The LangChain
|
||||
// ``tool_call_id`` joins this row to the chat tool card via the
|
||||
// ``data-action-log.lc_tool_call_id`` SSE event, and
|
||||
// ``chat_turn_id`` keys the per-turn revert endpoint.
|
||||
tool_call_id: z.string().nullable().optional(),
|
||||
chat_turn_id: z.string().nullable().optional(),
|
||||
created_at: z.string(),
|
||||
});
|
||||
|
||||
|
|
@ -38,6 +44,48 @@ const RevertResponseSchema = z.object({
|
|||
|
||||
export type RevertResponse = z.infer<typeof RevertResponseSchema>;
|
||||
|
||||
// Per-turn batch revert. The route never returns whole-batch 4xx;
|
||||
// partial success is the common case and surfaced as
|
||||
// ``status === "partial"`` with a per-action result list.
|
||||
const RevertTurnActionResultSchema = z.object({
|
||||
action_id: z.number(),
|
||||
tool_name: z.string(),
|
||||
status: z.enum([
|
||||
"reverted",
|
||||
"already_reverted",
|
||||
"not_reversible",
|
||||
"permission_denied",
|
||||
"failed",
|
||||
"skipped",
|
||||
]),
|
||||
message: z.string().nullable().optional(),
|
||||
new_action_id: z.number().nullable().optional(),
|
||||
error: z.string().nullable().optional(),
|
||||
});
|
||||
|
||||
export type RevertTurnActionResult = z.infer<typeof RevertTurnActionResultSchema>;
|
||||
|
||||
const RevertTurnResponseSchema = z.object({
|
||||
status: z.enum(["ok", "partial"]),
|
||||
chat_turn_id: z.string(),
|
||||
total: z.number(),
|
||||
reverted: z.number(),
|
||||
already_reverted: z.number(),
|
||||
not_reversible: z.number(),
|
||||
// ``permission_denied`` and ``skipped`` are first-class counters so
|
||||
// ``total === reverted + already_reverted +
|
||||
// not_reversible + permission_denied + failed + skipped`` always
|
||||
// holds. ``.default(0)`` keeps the schema backwards-compatible
|
||||
// with older deployments that haven't shipped the response model
|
||||
// update yet.
|
||||
permission_denied: z.number().default(0),
|
||||
failed: z.number(),
|
||||
skipped: z.number().default(0),
|
||||
results: z.array(RevertTurnActionResultSchema),
|
||||
});
|
||||
|
||||
export type RevertTurnResponse = z.infer<typeof RevertTurnResponseSchema>;
|
||||
|
||||
class AgentActionsApiService {
|
||||
listForThread = async (
|
||||
threadId: number,
|
||||
|
|
@ -59,6 +107,14 @@ class AgentActionsApiService {
|
|||
{ body: {} }
|
||||
);
|
||||
};
|
||||
|
||||
/**
 * POST to the per-turn revert route of a thread.
 *
 * ``chatTurnId`` is URI-encoded before it is placed in the path. The
 * request body is an empty object, and ``RevertTurnResponseSchema`` is
 * handed to ``baseApiService.post`` (presumably to validate the response
 * — confirm in the base service).
 */
revertTurn = async (threadId: number, chatTurnId: string): Promise<RevertTurnResponse> => {
  const turnSegment = encodeURIComponent(chatTurnId);
  const url = `/api/v1/threads/${threadId}/revert-turn/${turnSegment}`;
  return baseApiService.post(url, RevertTurnResponseSchema, { body: {} });
};
|
||||
}
|
||||
|
||||
export const agentActionsApiService = new AgentActionsApiService();
|
||||
|
|
|
|||
8
surfsense_web/lib/chat/mention-doc-key.ts
Normal file
8
surfsense_web/lib/chat/mention-doc-key.ts
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
/** The two fields of a mentioned document that make up its key. */
interface MentionKeyInput {
  id: number;
  document_type?: string | null;
}

/**
 * Build the string key ``<document_type>:<id>`` for a mentioned document.
 * A missing or ``null`` ``document_type`` is written as ``UNKNOWN``.
 */
export function getMentionDocKey(doc: MentionKeyInput): string {
  const docType = doc.document_type ?? "UNKNOWN";
  return `${docType}:${doc.id}`;
}
|
||||
|
|
@ -40,7 +40,7 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
|
|||
}
|
||||
|
||||
const metadata =
|
||||
msg.author_id || msg.token_usage
|
||||
msg.author_id || msg.token_usage || msg.turn_id
|
||||
? {
|
||||
custom: {
|
||||
...(msg.author_id && {
|
||||
|
|
@ -50,6 +50,10 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
|
|||
},
|
||||
}),
|
||||
...(msg.token_usage && { usage: msg.token_usage }),
|
||||
// Surface ``chat_turn_id`` so the assistant message
|
||||
// footer can scope its "Revert turn" button to just
|
||||
// this turn's actions. Null on legacy rows.
|
||||
...(msg.turn_id && { chatTurnId: msg.turn_id }),
|
||||
},
|
||||
}
|
||||
: undefined;
|
||||
|
|
|
|||
|
|
@ -9,21 +9,59 @@ export interface ThinkingStepData {
|
|||
|
||||
export type ContentPart =
|
||||
| { type: "text"; text: string }
|
||||
| { type: "reasoning"; text: string }
|
||||
| {
|
||||
type: "tool-call";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
args: Record<string, unknown>;
|
||||
result?: unknown;
|
||||
/**
|
||||
* Live / finalized JSON text for the tool's input arguments.
|
||||
*
|
||||
* - During streaming: accumulated partial JSON text from
|
||||
* ``tool-input-delta`` events (may be invalid JSON
|
||||
* mid-stream). assistant-ui's argsText parser tolerates
|
||||
* invalid JSON gracefully (changelog 0.7.32 / 0.7.78).
|
||||
* - On completion (``tool-input-available``): replaced with
|
||||
* ``JSON.stringify(input, null, 2)`` so the post-stream
|
||||
* card renders pretty-printed JSON instead of the
|
||||
* model's possibly-fragmented formatting.
|
||||
*
|
||||
* Per assistant-ui ``ThreadMessageLike`` precedence
|
||||
* (changelog 0.11.6 ``d318c83``), when ``argsText`` is
|
||||
* supplied it wins over ``JSON.stringify(args)``.
|
||||
*/
|
||||
argsText?: string;
|
||||
/**
|
||||
* Authoritative LangChain ``tool_call.id`` propagated by the backend
|
||||
* via ``langchainToolCallId`` on tool-input-start/available and
|
||||
* tool-output-available events. Used to join a card to the
|
||||
* matching ``AgentActionLog`` row exposed by
|
||||
* ``GET /threads/{id}/actions`` and the streamed
|
||||
* ``data-action-log`` events.
|
||||
*/
|
||||
langchainToolCallId?: string;
|
||||
}
|
||||
| {
|
||||
type: "data-thinking-steps";
|
||||
data: { steps: ThinkingStepData[] };
|
||||
}
|
||||
| {
|
||||
/**
|
||||
* Between-step separator. Pushed by `addStepSeparator` when
|
||||
* a `start-step` SSE event arrives AFTER the message already
|
||||
* has non-step content. Rendered by `StepSeparatorDataUI`
|
||||
* (see assistant-ui/step-separator.tsx).
|
||||
*/
|
||||
type: "data-step-separator";
|
||||
data: { stepIndex: number };
|
||||
};
|
||||
|
||||
export interface ContentPartsState {
|
||||
contentParts: ContentPart[];
|
||||
currentTextPartIndex: number;
|
||||
currentReasoningPartIndex: number;
|
||||
toolCallIndices: Map<string, number>;
|
||||
}
|
||||
|
||||
|
|
@ -74,6 +112,9 @@ export function updateThinkingSteps(
|
|||
if (state.currentTextPartIndex >= 0) {
|
||||
state.currentTextPartIndex += 1;
|
||||
}
|
||||
if (state.currentReasoningPartIndex >= 0) {
|
||||
state.currentReasoningPartIndex += 1;
|
||||
}
|
||||
for (const [id, idx] of state.toolCallIndices) {
|
||||
state.toolCallIndices.set(id, idx + 1);
|
||||
}
|
||||
|
|
@ -131,6 +172,12 @@ export class FrameBatchedUpdater {
|
|||
}
|
||||
|
||||
export function appendText(state: ContentPartsState, delta: string): void {
|
||||
// First text delta after a reasoning block: close the reasoning so
|
||||
// the assistant-ui renderer treats them as separate parts (the
|
||||
// reasoning block collapses; the answer streams below).
|
||||
if (state.currentReasoningPartIndex >= 0) {
|
||||
state.currentReasoningPartIndex = -1;
|
||||
}
|
||||
if (
|
||||
state.currentTextPartIndex >= 0 &&
|
||||
state.contentParts[state.currentTextPartIndex]?.type === "text"
|
||||
|
|
@ -143,39 +190,161 @@ export function appendText(state: ContentPartsState, delta: string): void {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Append a reasoning delta to the message being streamed.
 *
 * Any open text part is closed first. The delta is then added to the
 * currently open reasoning part; if none is open, a new reasoning part is
 * pushed and becomes the open one.
 */
export function appendReasoning(state: ContentPartsState, delta: string): void {
  if (state.currentTextPartIndex >= 0) {
    state.currentTextPartIndex = -1;
  }

  const open =
    state.currentReasoningPartIndex >= 0
      ? state.contentParts[state.currentReasoningPartIndex]
      : undefined;
  if (open?.type === "reasoning") {
    open.text += delta;
    return;
  }

  // ``push`` returns the new length, so the new part sits at length - 1.
  state.currentReasoningPartIndex = state.contentParts.push({ type: "reasoning", text: delta }) - 1;
}
|
||||
|
||||
/**
 * Close the open reasoning part, if any. Content is left untouched; the
 * next reasoning delta will start a new part.
 */
export function endReasoning(state: ContentPartsState): void {
  state.currentReasoningPartIndex = -1;
}
|
||||
|
||||
/**
 * Push a ``data-step-separator`` part between consecutive model steps of
 * one assistant turn.
 *
 * Nothing is pushed when the message has no text, reasoning or tool-call
 * part yet (so the first step gets no leading divider), or when the last
 * part is already a separator (guards against repeated ``start-step``
 * events). The new separator's ``stepIndex`` is the number of separators
 * already present. Pushing one also closes the open text and reasoning
 * parts.
 */
export function addStepSeparator(state: ContentPartsState): void {
  const parts = state.contentParts;

  let separatorsSoFar = 0;
  let sawContent = false;
  for (const part of parts) {
    if (part.type === "data-step-separator") {
      separatorsSoFar += 1;
    } else if (part.type === "text" || part.type === "reasoning" || part.type === "tool-call") {
      sawContent = true;
    }
  }

  if (!sawContent) return;
  if (parts[parts.length - 1]?.type === "data-step-separator") return;

  parts.push({ type: "data-step-separator", data: { stepIndex: separatorsSoFar } });
  state.currentTextPartIndex = -1;
  state.currentReasoningPartIndex = -1;
}
|
||||
|
||||
/**
 * Which tools get a UI tool card: either an explicit set of tool names or
 * the sentinel ``"all"``, which admits every tool.
 *
 * The sentinel exists because the legacy ``BASE_TOOLS_WITH_UI`` gate was
 * dropped so that ALL tool calls render through the generic
 * ``ToolFallback``. Persisted message JSON is kept small on the backend
 * side instead (``format_thinking_step`` summarisation and the defensive
 * ``result_length``-only default for unknown tools).
 */
export type ToolUIGate = Set<string> | "all";

/** True when `toolName` is admitted by `gate`. */
function _toolPasses(gate: ToolUIGate, toolName: string): boolean {
  if (gate === "all") return true;
  return gate.has(toolName);
}
|
||||
|
||||
/**
 * Register a new tool-call part on the message being streamed.
 *
 * The part is pushed only when ``force`` is set or ``toolName`` passes the
 * ``toolsWithUI`` gate; otherwise the state is left untouched. On push the
 * part's position is recorded in ``toolCallIndices`` under ``toolCallId``,
 * and the open text and reasoning parts are closed.
 *
 * @param state - Mutable content-parts state of the current message.
 * @param toolsWithUI - Gate deciding which tools get a card.
 * @param toolCallId - Id the card is keyed by.
 * @param toolName - Canonical tool name.
 * @param args - Tool input arguments known so far.
 * @param force - Push even when the gate rejects ``toolName``.
 * @param langchainToolCallId - LangChain ``tool_call.id``; stored on the
 *   part only when non-empty.
 */
export function addToolCall(
  state: ContentPartsState,
  toolsWithUI: ToolUIGate,
  toolCallId: string,
  toolName: string,
  args: Record<string, unknown>,
  force = false,
  langchainToolCallId?: string
): void {
  if (!force && !_toolPasses(toolsWithUI, toolName)) return;

  state.contentParts.push({
    type: "tool-call",
    toolCallId,
    toolName,
    args,
    // Omit the key entirely rather than storing ``undefined`` / "".
    ...(langchainToolCallId ? { langchainToolCallId } : {}),
  });
  state.toolCallIndices.set(toolCallId, state.contentParts.length - 1);
  state.currentTextPartIndex = -1;
  state.currentReasoningPartIndex = -1;
}
|
||||
|
||||
/**
 * Reverse lookup from a LangChain ``tool_call.id`` to the card's own
 * ``toolCallId``.
 *
 * Scans the parts in order and returns the ``toolCallId`` of the first
 * tool-call part whose ``langchainToolCallId`` equals ``lcToolCallId``.
 *
 * @returns The matching ``toolCallId``, or ``null`` when no tool-call part
 *   carries that LangChain id (yet).
 */
export function findToolCallIdByLcId(
  state: ContentPartsState,
  lcToolCallId: string
): string | null {
  const match = state.contentParts.find(
    (part) => part.type === "tool-call" && part.langchainToolCallId === lcToolCallId
  );
  return match?.type === "tool-call" ? match.toolCallId : null;
}
|
||||
|
||||
/**
 * Patch an already registered tool-call part in place.
 *
 * Does nothing when ``toolCallId`` has no recorded index or the part at
 * that index is not a tool call.
 *
 * @param state - Mutable content-parts state of the current message.
 * @param toolCallId - Id the card was registered under.
 * @param update - Fields to apply:
 *   - ``args``: replaces the part's args when truthy;
 *   - ``argsText``: applied whenever it is not ``undefined``, so an empty
 *     string clears the text;
 *   - ``result``: applied whenever it is not ``undefined`` (``null`` is
 *     stored as-is);
 *   - ``langchainToolCallId``: first non-empty id wins (see below).
 */
export function updateToolCall(
  state: ContentPartsState,
  toolCallId: string,
  update: {
    args?: Record<string, unknown>;
    argsText?: string;
    result?: unknown;
    langchainToolCallId?: string;
  }
): void {
  const index = state.toolCallIndices.get(toolCallId);
  if (index === undefined || state.contentParts[index]?.type !== "tool-call") return;

  const tc = state.contentParts[index] as ContentPart & { type: "tool-call" };
  if (update.args) tc.args = update.args;
  // ``!== undefined`` (NOT a truthy check): an explicit empty string must
  // be able to clear the text.
  if (update.argsText !== undefined) tc.argsText = update.argsText;
  if (update.result !== undefined) tc.result = update.result;
  // Backfill only: the id is written when the part has none yet. A later
  // empty/missing value cannot wipe a known id, and a later different
  // value does not replace it either.
  if (update.langchainToolCallId && !tc.langchainToolCallId) {
    tc.langchainToolCallId = update.langchainToolCallId;
  }
}
|
||||
|
||||
/**
 * Append one streamed args-delta chunk to a tool call's ``argsText``.
 *
 * A missing ``argsText`` is treated as the empty string. The call is a
 * no-op when ``toolCallId`` has no registered card, or when the part at
 * its recorded index is not a tool call — in both cases the delta has
 * nowhere safe to land.
 */
export function appendToolInputDelta(
  state: ContentPartsState,
  toolCallId: string,
  delta: string
): void {
  const slot = state.toolCallIndices.get(toolCallId);
  const card = slot === undefined ? undefined : state.contentParts[slot];
  if (card?.type === "tool-call") {
    card.argsText = `${card.argsText ?? ""}${delta}`;
  }
}
|
||||
|
||||
function _hasInterruptResult(part: ContentPart): boolean {
|
||||
if (part.type !== "tool-call") return false;
|
||||
const r = (part as { result?: unknown }).result;
|
||||
|
|
@ -184,13 +353,15 @@ function _hasInterruptResult(part: ContentPart): boolean {
|
|||
|
||||
export function buildContentForUI(
|
||||
state: ContentPartsState,
|
||||
toolsWithUI: Set<string>
|
||||
toolsWithUI: ToolUIGate
|
||||
): ThreadMessageLike["content"] {
|
||||
const filtered = state.contentParts.filter((part) => {
|
||||
if (part.type === "text") return part.text.length > 0;
|
||||
if (part.type === "reasoning") return part.text.length > 0;
|
||||
if (part.type === "tool-call")
|
||||
return toolsWithUI.has(part.toolName) || _hasInterruptResult(part);
|
||||
return _toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part);
|
||||
if (part.type === "data-thinking-steps") return true;
|
||||
if (part.type === "data-step-separator") return true;
|
||||
return false;
|
||||
});
|
||||
return filtered.length > 0
|
||||
|
|
@ -200,20 +371,28 @@ export function buildContentForUI(
|
|||
|
||||
export function buildContentForPersistence(
|
||||
state: ContentPartsState,
|
||||
toolsWithUI: Set<string>
|
||||
toolsWithUI: ToolUIGate
|
||||
): unknown[] {
|
||||
const parts: unknown[] = [];
|
||||
|
||||
for (const part of state.contentParts) {
|
||||
if (part.type === "text" && part.text.length > 0) {
|
||||
parts.push(part);
|
||||
} else if (part.type === "reasoning" && part.text.length > 0) {
|
||||
// Persist reasoning blocks so a chat reload re-renders the
|
||||
// collapsed thinking section instead of
|
||||
// silently dropping it (mirrors the data-thinking-steps
|
||||
// branch above).
|
||||
parts.push(part);
|
||||
} else if (
|
||||
part.type === "tool-call" &&
|
||||
(toolsWithUI.has(part.toolName) || _hasInterruptResult(part))
|
||||
(_toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part))
|
||||
) {
|
||||
parts.push(part);
|
||||
} else if (part.type === "data-thinking-steps") {
|
||||
parts.push(part);
|
||||
} else if (part.type === "data-step-separator") {
|
||||
parts.push(part);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -221,23 +400,134 @@ export function buildContentForPersistence(
|
|||
}
|
||||
|
||||
export type SSEEvent =
|
||||
| { type: "text-delta"; delta: string }
|
||||
| { type: "tool-input-start"; toolCallId: string; toolName: string }
|
||||
| { type: "start"; messageId?: string }
|
||||
| { type: "finish" }
|
||||
| { type: "start-step" }
|
||||
| { type: "finish-step" }
|
||||
| { type: "text-start"; id: string }
|
||||
| { type: "text-delta"; id?: string; delta: string }
|
||||
| { type: "text-end"; id: string }
|
||||
| { type: "reasoning-start"; id: string }
|
||||
| { type: "reasoning-delta"; id?: string; delta: string }
|
||||
| { type: "reasoning-end"; id: string }
|
||||
| {
|
||||
type: "tool-input-start";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
/** Authoritative LangChain ``tool_call.id``. Optional. */
|
||||
langchainToolCallId?: string;
|
||||
}
|
||||
| {
|
||||
/**
|
||||
* Live tool-call argument delta. Concatenated into
|
||||
* ``argsText`` on the matching ``tool-call`` content part
|
||||
* by ``appendToolInputDelta``. parity_v2 only — the legacy
|
||||
* code path emits ``tool-input-available`` without prior
|
||||
* deltas.
|
||||
*/
|
||||
type: "tool-input-delta";
|
||||
toolCallId: string;
|
||||
inputTextDelta: string;
|
||||
}
|
||||
| {
|
||||
type: "tool-input-available";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
input: Record<string, unknown>;
|
||||
langchainToolCallId?: string;
|
||||
}
|
||||
| {
|
||||
type: "tool-output-available";
|
||||
toolCallId: string;
|
||||
output: Record<string, unknown>;
|
||||
/** Authoritative LangChain ``tool_call.id`` extracted from
|
||||
* ``ToolMessage.tool_call_id`` at on_tool_end. Backfills cards
|
||||
* that didn't get the id at tool-input-start time. */
|
||||
langchainToolCallId?: string;
|
||||
}
|
||||
| { type: "data-thinking-step"; data: ThinkingStepData }
|
||||
| { type: "data-thread-title-update"; data: { threadId: number; title: string } }
|
||||
| { type: "data-interrupt-request"; data: Record<string, unknown> }
|
||||
| { type: "data-documents-updated"; data: Record<string, unknown> }
|
||||
| {
|
||||
/**
|
||||
* A freshly committed AgentActionLog row. Frontend stores
|
||||
* this in a Map keyed off ``lc_tool_call_id`` so the chat
|
||||
* tool card can light up its Revert button.
|
||||
*/
|
||||
type: "data-action-log";
|
||||
data: {
|
||||
id: number;
|
||||
lc_tool_call_id: string | null;
|
||||
chat_turn_id: string | null;
|
||||
tool_name: string;
|
||||
reversible: boolean;
|
||||
reverse_descriptor_present: boolean;
|
||||
created_at: string | null;
|
||||
error: boolean;
|
||||
};
|
||||
}
|
||||
| {
|
||||
/**
|
||||
* Reversibility flipped (filesystem op SAVEPOINT committed;
|
||||
* cf. ``kb_persistence._dispatch_reversibility_update``).
|
||||
*/
|
||||
type: "data-action-log-updated";
|
||||
data: { id: number; reversible: boolean };
|
||||
}
|
||||
| {
|
||||
/**
|
||||
* Emitted at the start of every stream so the frontend can
|
||||
* stamp the per-turn correlation id onto the in-flight
|
||||
* assistant message and replay it via
|
||||
* ``appendMessage``. Pure-text turns never produce
|
||||
* action-log events; this event guarantees the frontend
|
||||
* always learns the turn id.
|
||||
*/
|
||||
type: "data-turn-info";
|
||||
data: { chat_turn_id: string };
|
||||
}
|
||||
| {
|
||||
/**
|
||||
* Best-effort revert pass that ran BEFORE this regeneration.
|
||||
* Per-action results are forwarded to the UI so the user
|
||||
* can see which downstream actions were rolled
|
||||
* back vs which couldn't be undone.
|
||||
*/
|
||||
type: "data-revert-results";
|
||||
data: {
|
||||
status: "ok" | "partial";
|
||||
chat_turn_ids: string[];
|
||||
total: number;
|
||||
reverted: number;
|
||||
already_reverted: number;
|
||||
not_reversible: number;
|
||||
/**
|
||||
* ``permission_denied`` and ``skipped`` are first-class
|
||||
* counters so the response invariant
|
||||
* ``total === sum(counters)`` always holds. Optional
|
||||
* for forward compatibility with older backends; the
|
||||
* frontend treats missing values as ``0``.
|
||||
*/
|
||||
permission_denied?: number;
|
||||
failed: number;
|
||||
skipped?: number;
|
||||
results: Array<{
|
||||
action_id: number;
|
||||
tool_name: string;
|
||||
status:
|
||||
| "reverted"
|
||||
| "already_reverted"
|
||||
| "not_reversible"
|
||||
| "permission_denied"
|
||||
| "failed"
|
||||
| "skipped";
|
||||
message?: string | null;
|
||||
new_action_id?: number | null;
|
||||
error?: string | null;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
| {
|
||||
type: "data-token-usage";
|
||||
data: {
|
||||
|
|
|
|||
|
|
@ -46,6 +46,11 @@ export interface MessageRecord {
|
|||
author_display_name?: string | null;
|
||||
author_avatar_url?: string | null;
|
||||
token_usage?: TokenUsageSummary | null;
|
||||
// Per-turn correlation id from ``configurable.turn_id`` at streaming
|
||||
// time (added in migration 136). Used by the per-turn revert
|
||||
// endpoint and edit-from-arbitrary-position. Nullable on legacy
|
||||
// rows that predate the column.
|
||||
turn_id?: string | null;
|
||||
}
|
||||
|
||||
export interface ThreadListResponse {
|
||||
|
|
@ -123,10 +128,20 @@ export async function getThreadMessages(threadId: number): Promise<ThreadHistory
|
|||
|
||||
/**
|
||||
* Append a message to a thread.
|
||||
*
|
||||
* ``turn_id`` is the per-turn correlation id streamed by the backend
|
||||
* via ``data-turn-info``. Persisting it lets later edits locate the
|
||||
* matching LangGraph checkpoint without HumanMessage scanning. Older
|
||||
* callers can still omit it for back-compat.
|
||||
*/
|
||||
export async function appendMessage(
|
||||
threadId: number,
|
||||
message: { role: "user" | "assistant" | "system"; content: unknown; token_usage?: unknown }
|
||||
message: {
|
||||
role: "user" | "assistant" | "system";
|
||||
content: unknown;
|
||||
token_usage?: unknown;
|
||||
turn_id?: string | null;
|
||||
}
|
||||
): Promise<MessageRecord> {
|
||||
return baseApiService.post<MessageRecord>(`/api/v1/threads/${threadId}/messages`, undefined, {
|
||||
body: message,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,39 @@
|
|||
import { loader } from "fumadocs-core/source";
|
||||
import { icons } from "lucide-react";
|
||||
import {
|
||||
BookOpen,
|
||||
ClipboardCheck,
|
||||
Compass,
|
||||
Container,
|
||||
Download,
|
||||
FlaskConical,
|
||||
Heart,
|
||||
Unplug,
|
||||
Wrench,
|
||||
} from "lucide-react";
|
||||
import { createElement } from "react";
|
||||
import { docs } from "@/.source/server";
|
||||
|
||||
/**
 * Explicit whitelist of Lucide icons used in docs frontmatter / meta.json.
 * Importing the full `icons` barrel would pull every Lucide icon (~1 400 SVGs)
 * into the docs bundle even though only a handful are referenced. Add new icons
 * here as docs pages are added.
 */
const DOCS_ICONS: Record<string, React.ComponentType> = {
	BookOpen,
	ClipboardCheck,
	Compass,
	Container,
	Download,
	FlaskConical,
	Heart,
	Unplug,
	Wrench,
};

/**
 * Docs content source served under `/docs`.
 *
 * The `icon` handler resolves an icon name to a React element using only
 * the `DOCS_ICONS` whitelist. Names that are missing or not whitelisted
 * yield `undefined`.
 */
export const source = loader({
	baseUrl: "/docs",
	source: docs.toFumadocsSource(),
	icon(icon) {
		// Own-property check on purpose: the `in` operator also matches keys
		// inherited from Object.prototype (e.g. "constructor", "toString"),
		// which would hand a non-component to createElement.
		if (icon && Object.prototype.hasOwnProperty.call(DOCS_ICONS, icon)) {
			return createElement(DOCS_ICONS[icon]);
		}
		return undefined;
	},
});
|
||||
|
|
|
|||
|
|
@ -8,6 +8,13 @@ export const newChatMessageTable = table("new_chat_messages")
|
|||
threadId: number().from("thread_id"),
|
||||
authorId: string().optional().from("author_id"),
|
||||
createdAt: number().from("created_at"),
|
||||
// Per-turn correlation id sourced from ``configurable.turn_id``
|
||||
// at streaming time. Required by the inline Revert button's
|
||||
// (chat_turn_id, tool_name, position) fallback in tool-fallback.tsx
|
||||
// — without it the live-collab Zero sync would clobber the
|
||||
// metadata we set during streaming and the button would vanish
|
||||
// the moment Zero re-syncs after the stream finishes.
|
||||
turnId: string().optional().from("turn_id"),
|
||||
})
|
||||
.primaryKey("id");
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue