Mirror of https://github.com/MODSetter/SurfSense.git
Synced 2026-05-07 14:52:39 +02:00

commit 5d3b8b9ca9
Merge remote-tracking branch 'upstream/dev' into feature/multi-agent

83 changed files with 10514 additions and 638 deletions
@@ -728,7 +728,8 @@ def _build_compiled_agent_blocking(
     repair_mw = None
     if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
         registered_names: set[str] = {t.name for t in tools}
-        # Tools owned by the standard deepagents middleware stack.
+        # Tools owned by the standard deepagents middleware stack and the
+        # SurfSense filesystem extension.
         registered_names |= {
             "write_todos",
             "ls",
@@ -739,6 +740,14 @@ def _build_compiled_agent_blocking(
             "grep",
             "execute",
             "task",
+            "mkdir",
+            "cd",
+            "pwd",
+            "move_file",
+            "rm",
+            "rmdir",
+            "list_tree",
+            "execute_code",
         }
         repair_mw = ToolCallNameRepairMiddleware(
             registered_tool_names=registered_names,
@@ -767,25 +776,51 @@ def _build_compiled_agent_blocking(
     #    on every safe read-only call (``ls``, ``read_file``, ``grep``,
     #    ``glob``, ``web_search`` …) and, on resume, replay the previous
     #    reject decision into innocent calls.
-    # 2. ``connector_synthesized`` — deny rules for tools whose required
-    #    connector is not connected to this space. Overrides #1.
-    # 3. (future) user-defined rules from ``agent_permission_rules`` table
-    #    via the Agent Permissions UI. Loaded last so they override both.
+    # 2. ``desktop_safety`` — ``ask`` for destructive filesystem ops when
+    #    the agent is operating against the user's real disk. Cloud mode
+    #    has full revision-based revert via ``revert_service``, but
+    #    desktop mode hits disk immediately with no undo, so an
+    #    accidental ``rm`` / ``rmdir`` / ``move_file`` / ``edit_file`` /
+    #    ``write_file`` is unrecoverable. This layer is forced on in
+    #    desktop mode regardless of ``enable_permission`` because the
+    #    safety net is non-negotiable.
+    # 3. ``connector_synthesized`` — deny rules for tools whose required
+    #    connector is not connected to this space. Overrides #1/#2.
+    # 4. (future) user-defined rules from ``agent_permission_rules`` table
+    #    via the Agent Permissions UI. Loaded last so they override all.
     permission_mw: PermissionMiddleware | None = None
-    if flags.enable_permission and not flags.disable_new_agent_stack:
-        synthesized = _synthesize_connector_deny_rules(
-            available_connectors=available_connectors,
-            enabled_tool_names={t.name for t in tools},
-        )
-        permission_mw = PermissionMiddleware(
-            rulesets=[
-                Ruleset(
-                    rules=[Rule(permission="*", pattern="*", action="allow")],
-                    origin="surfsense_defaults",
-                ),
-                Ruleset(rules=synthesized, origin="connector_synthesized"),
-            ],
-        )
+    is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
+    permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
+    # Build the middleware whenever it has work to do: either the user
+    # opted into the rule engine, OR we're in desktop mode and need the
+    # safety rules unconditionally.
+    if permission_enabled or is_desktop_fs:
+        rulesets: list[Ruleset] = [
+            Ruleset(
+                rules=[Rule(permission="*", pattern="*", action="allow")],
+                origin="surfsense_defaults",
+            ),
+        ]
+        if is_desktop_fs:
+            rulesets.append(
+                Ruleset(
+                    rules=[
+                        Rule(permission="rm", pattern="*", action="ask"),
+                        Rule(permission="rmdir", pattern="*", action="ask"),
+                        Rule(permission="move_file", pattern="*", action="ask"),
+                        Rule(permission="edit_file", pattern="*", action="ask"),
+                        Rule(permission="write_file", pattern="*", action="ask"),
+                    ],
+                    origin="desktop_safety",
+                )
+            )
+        if permission_enabled:
+            synthesized = _synthesize_connector_deny_rules(
+                available_connectors=available_connectors,
+                enabled_tool_names={t.name for t in tools},
+            )
+            rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized"))
+        permission_mw = PermissionMiddleware(rulesets=rulesets)

     # ActionLogMiddleware. Off by default until the ``agent_action_log``
     # table is migrated. When enabled, persists one row per tool call
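The override order described in the comments falls out of a simple last-match-wins evaluation over the layered rulesets. A minimal, self-contained sketch of that resolution order; ``Rule``/``Ruleset`` mirror the fields used in the diff, while ``resolve`` is hypothetical, since ``PermissionMiddleware``'s real evaluator is not shown here:

from dataclasses import dataclass
from fnmatch import fnmatch


@dataclass(frozen=True)
class Rule:
    permission: str  # tool name, "*" for any
    pattern: str     # argument pattern, "*" for any
    action: str      # "allow" | "ask" | "deny"


@dataclass(frozen=True)
class Ruleset:
    rules: list[Rule]
    origin: str


def resolve(rulesets: list[Ruleset], tool: str, arg: str) -> str:
    """Last matching rule wins, so later rulesets override earlier ones."""
    action = "ask"  # conservative default when nothing matches
    for ruleset in rulesets:
        for rule in ruleset.rules:
            if fnmatch(tool, rule.permission) and fnmatch(arg, rule.pattern):
                action = rule.action
    return action


layers = [
    Ruleset([Rule("*", "*", "allow")], "surfsense_defaults"),
    Ruleset([Rule("rm", "*", "ask")], "desktop_safety"),
    Ruleset([Rule("slack_*", "*", "deny")], "connector_synthesized"),
]
assert resolve(layers, "read_file", "/documents/a.md") == "allow"
assert resolve(layers, "rm", "/mnt/notes/a.md") == "ask"        # desktop safety wins
assert resolve(layers, "slack_post_message", "#general") == "deny"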
@@ -942,6 +977,7 @@ def _build_compiled_agent_blocking(
                 search_space_id=search_space_id,
                 created_by_id=user_id,
                 filesystem_mode=filesystem_mode,
+                thread_id=thread_id,
             )
             if filesystem_mode == FilesystemMode.CLOUD
             else None,
@@ -23,6 +23,7 @@ Local development (recommended for trying everything except doom-loop / selector
 SURFSENSE_ENABLE_PERMISSION=false        # default off, opt-in per deploy
 SURFSENSE_ENABLE_DOOM_LOOP=false         # default off until UI ships
 SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
+SURFSENSE_ENABLE_STREAM_PARITY_V2=false  # structured streaming events

 Master kill-switch (overrides everything else):

@@ -86,6 +87,15 @@ class AgentFeatureFlags:
         False  # Backend ships before UI; route returns 503 until this flips
     )

+    # Streaming parity v2 — opt in to LangChain's structured
+    # ``AIMessageChunk`` content (typed reasoning blocks, tool-input
+    # deltas) and propagate the real ``tool_call_id`` to the SSE layer.
+    # When OFF the ``stream_new_chat`` task falls back to the str-only
+    # text path and the synthetic ``call_<run_id>`` tool-call id (no
+    # ``langchainToolCallId`` propagation). Schema migrations 135/136
+    # ship unconditionally because they're forward-compatible.
+    enable_stream_parity_v2: bool = False
+
     # Plugins
     enable_plugin_loader: bool = False

@@ -139,6 +149,10 @@ class AgentFeatureFlags:
             # Snapshot / revert
             enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
             enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
+            # Streaming parity v2
+            enable_stream_parity_v2=_env_bool(
+                "SURFSENSE_ENABLE_STREAM_PARITY_V2", False
+            ),
             # Plugins
             enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
             # Observability
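Every flag above funnels through the same ``_env_bool`` helper. Its body is not part of this diff; a minimal sketch of the usual shape, assuming the conventional truthy spellings:

import os


def _env_bool(name: str, default: bool) -> bool:
    """Parse a boolean feature flag from the environment.

    Unset -> ``default``; anything not clearly truthy parses as False,
    so a typo can't silently enable a feature.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}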
@@ -5,9 +5,14 @@ extra fields needed to implement Postgres-backed virtual filesystem semantics:

 * ``cwd`` — current working directory (per-thread checkpointed).
 * ``staged_dirs`` — pending mkdir requests (cloud only).
+* ``staged_dir_tool_calls`` — sidecar map ``path -> tool_call_id`` for staged dirs.
 * ``pending_moves`` — pending move_file requests (cloud only).
+* ``pending_deletes`` — pending ``rm`` requests (cloud only).
+* ``pending_dir_deletes`` — pending ``rmdir`` requests (cloud only).
 * ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads.
 * ``dirty_paths`` — paths whose state file content differs from DB.
+* ``dirty_path_tool_calls`` — sidecar map ``path -> latest tool_call_id`` for
+  dirty paths; used to bind the per-path snapshot to an action_id.
 * ``kb_priority`` — top-K priority hints rendered into a system message.
 * ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting.
 * ``kb_anon_doc`` — Redis-loaded anonymous document (if any).

@@ -32,12 +37,31 @@ from app.agents.new_chat.state_reducers import (
 )


-class PendingMove(TypedDict):
-    """A staged move_file operation pending end-of-turn commit."""
+class PendingMove(TypedDict, total=False):
+    """A staged move_file operation pending end-of-turn commit.
+
+    ``tool_call_id`` is optional for backward compatibility with checkpoints
+    written before the snapshot/revert pipeline was wired up; new entries
+    always include it so the persistence body can resolve an action_id.
+    """

     source: str
     dest: str
     overwrite: bool
+    tool_call_id: str
+
+
+class PendingDelete(TypedDict, total=False):
+    """A staged ``rm`` or ``rmdir`` operation pending end-of-turn commit.
+
+    ``tool_call_id`` is required for new entries (it's the binding key used
+    by :class:`KnowledgeBasePersistenceMiddleware` to find the matching
+    :class:`AgentActionLog` row and bind the snapshot to it). Marked
+    ``total=False`` only to tolerate older checkpoint payloads.
+    """
+
+    path: str
+    tool_call_id: str


 class KbPriorityEntry(TypedDict, total=False):

@@ -76,9 +100,38 @@ class SurfSenseFilesystemState(FilesystemState):
     staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
     """mkdir paths staged for end-of-turn folder creation (cloud only)."""

+    staged_dir_tool_calls: NotRequired[
+        Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
+    ]
+    """``path -> tool_call_id`` sidecar for ``staged_dirs``.
+
+    Used by :class:`KnowledgeBasePersistenceMiddleware` to bind the
+    :class:`FolderRevision` snapshot to the originating ``mkdir`` action.
+    Kept separate from ``staged_dirs`` (which stays a unique-string list)
+    to avoid breaking ``_add_unique_reducer`` semantics.
+    """
+
     pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
     """move_file ops staged for end-of-turn commit (cloud only)."""

+    pending_deletes: NotRequired[Annotated[list[PendingDelete], _list_append_reducer]]
+    """``rm`` ops staged for end-of-turn ``DELETE FROM documents`` (cloud only).
+
+    Each entry is a dict ``{"path": ..., "tool_call_id": ...}``. Per-path
+    uniqueness is enforced inside the commit body, not the reducer (we keep
+    ``tool_call_id`` per occurrence so snapshot binding works).
+    """
+
+    pending_dir_deletes: NotRequired[
+        Annotated[list[PendingDelete], _list_append_reducer]
+    ]
+    """``rmdir`` ops staged for end-of-turn ``DELETE FROM folders`` (cloud only).
+
+    Same shape as :data:`pending_deletes`. Commit body re-verifies the
+    folder is empty (in-DB AND with this turn's pending changes accounted
+    for) before issuing the DELETE.
+    """
+
     doc_id_by_path: NotRequired[
         Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
     ]

@@ -92,6 +145,17 @@ class SurfSenseFilesystemState(FilesystemState):
     dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
     """Paths whose ``state["files"]`` content has been modified this turn."""

+    dirty_path_tool_calls: NotRequired[
+        Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
+    ]
+    """``path -> latest tool_call_id`` sidecar for ``dirty_paths``.
+
+    The persistence body coalesces multiple writes/edits to the same path
+    into one snapshot per turn. This map captures the most-recent
+    ``tool_call_id`` so the resulting :class:`DocumentRevision` is bound
+    to the latest action_id (the one the user is most likely to revert).
+    """
+
     kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
     """Top-K priority hints rendered as a system message before the user turn."""

@@ -108,6 +172,7 @@ class SurfSenseFilesystemState(FilesystemState):
 __all__ = [
     "KbAnonDoc",
     "KbPriorityEntry",
+    "PendingDelete",
     "PendingMove",
     "SurfSenseFilesystemState",
 ]
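The sidecar maps above are annotated with ``_dict_merge_with_tombstones_reducer``, whose body is not shown in this diff. From how the tools use it (emitting ``{path: None}`` to retract an earlier ``{path: tool_call_id}`` entry), a minimal sketch, assuming a plain two-argument LangGraph-style reducer:

def _dict_merge_with_tombstones_reducer(
    left: dict[str, str] | None, right: dict[str, str] | None
) -> dict[str, str]:
    """Merge two dict updates; a ``None`` value is a tombstone that deletes the key.

    Lets a tool emit ``{path: None}`` to retract an earlier entry instead of
    having to replace the whole map in one update.
    """
    merged = dict(left or {})
    for key, value in (right or {}).items():
        if value is None:
            merged.pop(key, None)
        else:
            merged[key] = value
    return merged

The list-valued fields pair the same idea with ``_add_unique_reducer`` / ``_list_append_reducer`` plus a ``_CLEAR`` sentinel, which the ``rm``/``rmdir`` tools below use to rebuild a list minus one entry.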
@@ -30,6 +30,7 @@ from collections.abc import Awaitable, Callable
 from typing import TYPE_CHECKING, Any

 from langchain.agents.middleware import AgentMiddleware
+from langchain_core.callbacks import adispatch_custom_event
 from langchain_core.messages import ToolMessage

 from app.agents.new_chat.feature_flags import get_flags

@@ -144,11 +145,19 @@ class ActionLogMiddleware(AgentMiddleware):
             result=result,
         )

+        tool_call_id = _resolve_tool_call_id(request)
+        chat_turn_id = _resolve_chat_turn_id(request)
+
         row = AgentActionLog(
             thread_id=self._thread_id,
             user_id=self._user_id,
             search_space_id=self._search_space_id,
-            turn_id=_resolve_turn_id(request),
+            # ``turn_id`` is the deprecated alias of ``tool_call_id``
+            # kept for one release for safe rollback. New consumers
+            # should read ``tool_call_id`` directly.
+            turn_id=tool_call_id,
+            tool_call_id=tool_call_id,
+            chat_turn_id=chat_turn_id,
             message_id=_resolve_message_id(request),
             tool_name=tool_name,
             args=args_payload,

@@ -160,11 +169,41 @@ class ActionLogMiddleware(AgentMiddleware):
             async with shielded_async_session() as session:
                 session.add(row)
                 await session.commit()
+                row_id = int(row.id) if row.id is not None else None
+                row_created_at = row.created_at
         except Exception:
             logger.warning(
                 "ActionLogMiddleware failed to persist action log row",
                 exc_info=True,
             )
             return

+        # Surface a side-channel SSE event so the chat tool card can
+        # render a Revert button immediately after the row is durable.
+        # ``stream_new_chat`` translates this into a
+        # ``data-action-log`` SSE event. We DO NOT include the
+        # ``reverse_descriptor`` payload here; only a presence flag.
+        try:
+            await adispatch_custom_event(
+                "action_log",
+                {
+                    "id": row_id,
+                    "lc_tool_call_id": tool_call_id,
+                    "chat_turn_id": chat_turn_id,
+                    "tool_name": tool_name,
+                    "reversible": bool(reversible),
+                    "reverse_descriptor_present": reverse_descriptor is not None,
+                    "created_at": row_created_at.isoformat()
+                    if row_created_at
+                    else None,
+                    "error": error_payload is not None,
+                },
+            )
+        except Exception:
+            logger.debug(
+                "ActionLogMiddleware failed to dispatch action_log event",
+                exc_info=True,
+            )
+
     def _render_reverse(
         self,

@@ -254,7 +293,8 @@ def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
     }


-def _resolve_turn_id(request: Any) -> str | None:
+def _resolve_tool_call_id(request: Any) -> str | None:
+    """Return the LangChain ``tool_call.id`` for this request, if any."""
     try:
         call = getattr(request, "tool_call", None) or {}
         if isinstance(call, dict):

@@ -266,9 +306,40 @@ def _resolve_turn_id(request: Any) -> str | None:
     return None


+# Deprecated alias kept for one release. Old callers and tests treated
+# ``turn_id`` as if it carried the LangChain tool_call id; the new column
+# lives under ``tool_call_id``. Both resolve to the same value today.
+_resolve_turn_id = _resolve_tool_call_id
+
+
+def _resolve_chat_turn_id(request: Any) -> str | None:
+    """Return ``configurable.turn_id`` for this request, if accessible.
+
+    ``ToolRuntime.config`` is exposed by LangGraph (see
+    ``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id
+    lives at ``runtime.config["configurable"]["turn_id"]``.
+    """
+    try:
+        runtime = getattr(request, "runtime", None)
+        if runtime is None:
+            return None
+        config = getattr(runtime, "config", None)
+        if not isinstance(config, dict):
+            return None
+        configurable = config.get("configurable")
+        if not isinstance(configurable, dict):
+            return None
+        value = configurable.get("turn_id")
+        if isinstance(value, str) and value:
+            return value
+    except Exception:  # pragma: no cover - defensive
+        pass
+    return None
+
+
 def _resolve_message_id(request: Any) -> str | None:
     """Tool-call IDs serve as best-available message correlator at this layer."""
-    return _resolve_turn_id(request)
+    return _resolve_tool_call_id(request)


 def _resolve_result_id(result: Any) -> str | None:
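A quick way to exercise the defensive lookup in ``_resolve_chat_turn_id`` is to stub the request object; ``SimpleNamespace`` stands in here for the real runtime-bearing tool request:

from types import SimpleNamespace

# Happy path: the chat-turn id sits at runtime.config["configurable"]["turn_id"].
request = SimpleNamespace(
    runtime=SimpleNamespace(
        config={"configurable": {"turn_id": "chat-42:1715000000000"}}
    )
)
assert _resolve_chat_turn_id(request) == "chat-42:1715000000000"

# Every malformed shape degrades to None instead of raising.
assert _resolve_chat_turn_id(SimpleNamespace(runtime=None)) is None
assert _resolve_chat_turn_id(SimpleNamespace(runtime=SimpleNamespace(config=None))) is None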
@@ -102,6 +102,8 @@ current working directory (`cwd`, default `/documents`).
 - cd(path): change the current working directory.
 - pwd(): print the current working directory.
 - move_file(source, dest): move/rename a file under `/documents/`.
+- rm(path): delete a single file under `/documents/` (no `-r`).
+- rmdir(path): delete an empty directory under `/documents/`.
 - list_tree(path, max_depth, page_size): recursively list files/folders.

 ## Persistence Rules

@@ -112,8 +114,9 @@ current working directory (`cwd`, default `/documents`).
   `/documents/temp_scratch.md`) are **discarded** at end of turn — use this
   prefix for any scratch/working content you do NOT want saved.
 - All other paths (outside `/documents/` and not `temp_*`) are rejected.
-- mkdir/move_file are staged this turn and committed at end of turn alongside
-  any new/edited documents.
+- mkdir/move_file/rm/rmdir are staged this turn and committed at end of
+  turn alongside any new/edited documents. Snapshot/revert is enabled
+  for every destructive operation when action logging is on.

 ## Reading Documents Efficiently

@@ -176,6 +179,8 @@ directory (`cwd`).
 - cd(path): change the current working directory.
 - pwd(): print the current working directory.
 - move_file(source, dest): move/rename a file.
+- rm(path): delete a single file from disk (no `-r`). NOT reversible.
+- rmdir(path): delete an empty directory from disk. NOT reversible.
 - list_tree(path, max_depth, page_size): recursively list files/folders.

 ## Workflow Tips

@@ -184,6 +189,8 @@ directory (`cwd`).
 - For large trees, prefer `list_tree` then `grep` then `read_file` over
   brute-force directory traversal.
 - Cross-mount moves are not supported.
+- Desktop deletes hit disk immediately and cannot be undone via the
+  agent's revert flow — confirm before calling `rm`/`rmdir`.
 """
 )
@@ -355,6 +362,42 @@ Notes:
 - Parent folders are created as needed.
 """

+_CLOUD_RM_TOOL_DESCRIPTION = """Deletes a single file under `/documents/`.
+
+Mirrors POSIX `rm path` (no `-r`, no glob expansion). Stages the deletion
+for end-of-turn commit; the row is removed only after the agent's turn
+finishes successfully.
+
+Args:
+- path: absolute or relative file path. Cannot point at a directory — use
+  `rmdir` for empty folders. Cannot target the root or `/documents`.
+
+Notes:
+- The action is reversible via the per-action revert flow when action
+  logging is enabled.
+- The anonymous uploaded document is read-only and cannot be deleted.
+"""
+
+_CLOUD_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory under `/documents/`.
+
+Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
+deletion (`rm -r`) is intentionally NOT supported — clear contents with
+`rm` first.
+
+Args:
+- path: absolute or relative directory path. Cannot target the root,
+  `/documents`, the current cwd, or any ancestor of cwd (use `cd` to
+  move out first).
+
+Notes:
+- Emptiness is evaluated against the post-staged view, so a same-turn
+  `rm /a/x.md` followed by `rmdir /a` is fine.
+- If the directory was added in this same turn via `mkdir` and never
+  committed, the staged mkdir is dropped instead of issuing a delete.
+- The action is reversible via the per-action revert flow when action
+  logging is enabled.
+"""
+
 # --- desktop-only ----------------------------------------------------------

 _DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.

@@ -421,6 +464,28 @@ Notes:
 - Parent folders are created as needed.
 """

+_DESKTOP_RM_TOOL_DESCRIPTION = """Deletes a single file from disk.
+
+Mirrors POSIX `rm path` (no `-r`, no glob expansion). The deletion hits
+disk immediately. Desktop deletes are NOT reversible via the agent's
+revert flow.
+
+Args:
+- path: absolute mount-prefixed file path. Cannot point at a directory —
+  use `rmdir` for empty folders.
+"""
+
+_DESKTOP_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory from disk.
+
+Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
+deletion is NOT supported. The deletion hits disk immediately and is
+NOT reversible via the agent's revert flow.
+
+Args:
+- path: absolute mount-prefixed directory path. Cannot target the mount
+  root or any directory containing files/subfolders.
+"""
+

 def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
     """Pick the active-mode description for every filesystem tool."""

@@ -437,6 +502,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
             "mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION,
             "cd": SURFSENSE_CD_TOOL_DESCRIPTION,
             "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
+            "rm": _CLOUD_RM_TOOL_DESCRIPTION,
+            "rmdir": _CLOUD_RMDIR_TOOL_DESCRIPTION,
         }
     return {
         "ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION,

@@ -450,6 +517,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
         "mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION,
         "cd": SURFSENSE_CD_TOOL_DESCRIPTION,
         "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
+        "rm": _DESKTOP_RM_TOOL_DESCRIPTION,
+        "rmdir": _DESKTOP_RMDIR_TOOL_DESCRIPTION,
     }
@@ -476,6 +545,21 @@ def _basename(path: str) -> str:
     return path.rsplit("/", 1)[-1]


+def _is_ancestor_of(candidate: str, target: str) -> bool:
+    """True iff ``candidate`` is a strict ancestor directory of ``target``.
+
+    ``target`` itself is NOT considered an ancestor (use equality for that).
+    Both paths are assumed to be canonicalised, absolute, and free of
+    trailing slashes (except the root ``/``).
+    """
+    if not candidate.startswith("/") or not target.startswith("/"):
+        return False
+    if candidate == target:
+        return False
+    prefix = candidate.rstrip("/") + "/"
+    return target.startswith(prefix)
+
+
 class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
     """SurfSense-specific filesystem middleware (cloud + desktop)."""
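Spot checks of the strict-ancestor semantics, following directly from the implementation above:

assert _is_ancestor_of("/documents/a", "/documents/a/b/c.md")
assert _is_ancestor_of("/", "/documents")
assert not _is_ancestor_of("/documents/a", "/documents/a")     # strict: not its own ancestor
assert not _is_ancestor_of("/documents/ab", "/documents/abc")  # no bare string-prefix trap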
@@ -519,6 +603,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
         self.tools.append(self._create_cd_tool())
         self.tools.append(self._create_pwd_tool())
         self.tools.append(self._create_move_file_tool())
+        self.tools.append(self._create_rm_tool())
+        self.tools.append(self._create_rmdir_tool())
         self.tools.append(self._create_list_tree_tool())
         if self._sandbox_available:
             self.tools.append(self._create_execute_code_tool())

@@ -941,6 +1027,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
             }
             if self._is_cloud():
                 update["dirty_paths"] = [path]
+                update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
             return Command(update=update)

     def sync_write_file(

@@ -1036,6 +1123,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
             }
             if self._is_cloud():
                 update["dirty_paths"] = [path]
+                update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
             if doc_id_to_attach is not None:
                 update["doc_id_by_path"] = {path: doc_id_to_attach}
             return Command(update=update)

@@ -1103,6 +1191,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
             return Command(
                 update={
                     "staged_dirs": [validated],
+                    "staged_dir_tool_calls": {
+                        validated: runtime.tool_call_id,
+                    },
                     "messages": [
                         ToolMessage(
                             content=(

@@ -1372,7 +1463,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
             files_update: dict[str, Any] = {source: None, dest: source_file_data}
             update: dict[str, Any] = {
                 "files": files_update,
-                "pending_moves": [{"source": source, "dest": dest, "overwrite": False}],
+                "pending_moves": [
+                    {
+                        "source": source,
+                        "dest": dest,
+                        "overwrite": False,
+                        "tool_call_id": runtime.tool_call_id,
+                    }
+                ],
                 "messages": [
                     ToolMessage(
                         content=(
@@ -1396,6 +1494,323 @@
                 update["dirty_paths"] = new_dirty
             return Command(update=update)

+    # ------------------------------------------------------------------ tool: rm
+
+    def _create_rm_tool(self) -> BaseTool:
+        tool_description = (
+            self._custom_tool_descriptions.get("rm") or _CLOUD_RM_TOOL_DESCRIPTION
+        )
+
+        async def async_rm(
+            path: Annotated[
+                str,
+                "Absolute or relative path to the file to delete.",
+            ],
+            runtime: ToolRuntime[None, SurfSenseFilesystemState],
+        ) -> Command | str:
+            if not path or not path.strip():
+                return "Error: path is required."
+
+            target = self._resolve_relative(path, runtime)
+            try:
+                validated = validate_path(target)
+            except ValueError as exc:
+                return f"Error: {exc}"
+
+            if self._is_cloud():
+                if validated in ("/", DOCUMENTS_ROOT):
+                    return f"Error: refusing to rm '{validated}'."
+                if not validated.startswith(DOCUMENTS_ROOT + "/"):
+                    return (
+                        "Error: cloud rm must target a path under /documents/ "
+                        f"(got '{validated}')."
+                    )
+
+                anon = runtime.state.get("kb_anon_doc") or {}
+                if isinstance(anon, dict) and str(anon.get("path") or "") == validated:
+                    return "Error: the anonymous uploaded document is read-only."
+
+                # Refuse if the path looks like a directory.
+                staged_dirs = list(runtime.state.get("staged_dirs") or [])
+                if validated in staged_dirs:
+                    return (
+                        f"Error: '{validated}' is a directory. Use rmdir for "
+                        "empty directories."
+                    )
+                pending_dir_deletes = list(
+                    runtime.state.get("pending_dir_deletes") or []
+                )
+                if any(
+                    isinstance(d, dict) and d.get("path") == validated
+                    for d in pending_dir_deletes
+                ):
+                    return f"Error: '{validated}' is already queued for rmdir."
+
+                backend = self._get_backend(runtime)
+                if isinstance(backend, KBPostgresBackend):
+                    # Detect "is a directory" via `ls`: if the path lists
+                    # children we know it's a folder. Otherwise we still
+                    # need to confirm it's a real file before staging.
+                    children = await backend.als_info(validated)
+                    if children:
+                        return (
+                            f"Error: '{validated}' is a directory. Use rmdir for "
+                            "empty directories."
+                        )
+
+                # Already queued for delete this turn?
+                pending_deletes = list(runtime.state.get("pending_deletes") or [])
+                if any(
+                    isinstance(d, dict) and d.get("path") == validated
+                    for d in pending_deletes
+                ):
+                    return f"'{validated}' is already queued for deletion."
+
+                # Resolve doc_id (best-effort): file in state or DB.
+                files_state = runtime.state.get("files") or {}
+                doc_id_by_path = runtime.state.get("doc_id_by_path") or {}
+                resolved_doc_id: int | None = doc_id_by_path.get(validated)
+                if (
+                    validated not in files_state
+                    and resolved_doc_id is None
+                    and isinstance(backend, KBPostgresBackend)
+                ):
+                    loaded = await backend._load_file_data(validated)
+                    if loaded is None:
+                        return f"Error: file '{validated}' not found."
+                    _, resolved_doc_id = loaded
+
+                files_update: dict[str, Any] = {validated: None}
+                update: dict[str, Any] = {
+                    "pending_deletes": [
+                        {
+                            "path": validated,
+                            "tool_call_id": runtime.tool_call_id,
+                        }
+                    ],
+                    "files": files_update,
+                    "doc_id_by_path": {validated: None},
+                    "messages": [
+                        ToolMessage(
+                            content=(
+                                f"Staged delete of '{validated}' (will commit at "
+                                "end of turn)."
+                            ),
+                            tool_call_id=runtime.tool_call_id,
+                        )
+                    ],
+                }
+
+                # Drop the path from dirty_paths so a same-turn write+rm
+                # doesn't recreate the doc at commit time.
+                dirty_paths = list(runtime.state.get("dirty_paths") or [])
+                if validated in dirty_paths:
+                    new_dirty: list[Any] = [_CLEAR]
+                    for entry in dirty_paths:
+                        if entry != validated:
+                            new_dirty.append(entry)
+                    update["dirty_paths"] = new_dirty
+                    update["dirty_path_tool_calls"] = {validated: None}
+
+                return Command(update=update)
+
+            # Desktop mode — hit disk immediately.
+            backend = self._get_backend(runtime)
+            adelete = getattr(backend, "adelete_file", None)
+            if not callable(adelete):
+                return "Error: rm is not supported by the active backend."
+            res: WriteResult = await adelete(validated)
+            if res.error:
+                return res.error
+            update_desktop: dict[str, Any] = {
+                "files": {validated: None},
+                "messages": [
+                    ToolMessage(
+                        content=f"Deleted file '{res.path or validated}'",
+                        tool_call_id=runtime.tool_call_id,
+                    )
+                ],
+            }
+            return Command(update=update_desktop)
+
+        def sync_rm(
+            path: Annotated[
+                str,
+                "Absolute or relative path to the file to delete.",
+            ],
+            runtime: ToolRuntime[None, SurfSenseFilesystemState],
+        ) -> Command | str:
+            return self._run_async_blocking(async_rm(path, runtime))
+
+        return StructuredTool.from_function(
+            name="rm",
+            description=tool_description,
+            func=sync_rm,
+            coroutine=async_rm,
+        )
+
+    # ------------------------------------------------------------------ tool: rmdir
+
+    def _create_rmdir_tool(self) -> BaseTool:
+        tool_description = (
+            self._custom_tool_descriptions.get("rmdir") or _CLOUD_RMDIR_TOOL_DESCRIPTION
+        )
+
+        async def async_rmdir(
+            path: Annotated[
+                str,
+                "Absolute or relative path of the empty directory to delete.",
+            ],
+            runtime: ToolRuntime[None, SurfSenseFilesystemState],
+        ) -> Command | str:
+            if not path or not path.strip():
+                return "Error: path is required."
+
+            target = self._resolve_relative(path, runtime)
+            try:
+                validated = validate_path(target)
+            except ValueError as exc:
+                return f"Error: {exc}"
+
+            if self._is_cloud():
+                if validated in ("/", DOCUMENTS_ROOT):
+                    return f"Error: refusing to rmdir '{validated}'."
+                if not validated.startswith(DOCUMENTS_ROOT + "/"):
+                    return (
+                        "Error: cloud rmdir must target a path under /documents/ "
+                        f"(got '{validated}')."
+                    )
+
+                cwd = self._current_cwd(runtime)
+                if validated == cwd or _is_ancestor_of(validated, cwd):
+                    return (
+                        f"Error: cannot rmdir '{validated}' because the current "
+                        "cwd is at or under it. cd out first."
+                    )
+
+                staged_dirs = list(runtime.state.get("staged_dirs") or [])
+                pending_dir_deletes = list(
+                    runtime.state.get("pending_dir_deletes") or []
+                )
+                if any(
+                    isinstance(d, dict) and d.get("path") == validated
+                    for d in pending_dir_deletes
+                ):
+                    return f"'{validated}' is already queued for deletion."
+
+                backend = self._get_backend(runtime)
+
+                # The path must currently exist either in DB folder paths or
+                # in staged_dirs. We rely on KBPostgresBackend.als_info (which
+                # already accounts for pending deletes/moves) to evaluate
+                # both existence and emptiness against the post-staged view.
+                exists_in_staged = validated in staged_dirs
+                children: list[Any] = []
+                if isinstance(backend, KBPostgresBackend):
+                    children = list(await backend.als_info(validated))
+
+                # Detect "is a file" — if als_info returns no children but
+                # the path is actually a file, we should reject. We use
+                # _load_file_data to disambiguate file vs missing folder.
+                if (
+                    isinstance(backend, KBPostgresBackend)
+                    and not children
+                    and not exists_in_staged
+                ):
+                    loaded = await backend._load_file_data(validated)
+                    if loaded is not None:
+                        return (
+                            f"Error: '{validated}' is a file. Use rm to delete files."
+                        )
+                    # Confirm folder exists in DB by checking the parent listing.
+                    parent = posixpath.dirname(validated) or "/"
+                    parent_listing = await backend.als_info(parent)
+                    parent_has_dir = any(
+                        info.get("path") == validated and info.get("is_dir")
+                        for info in parent_listing
+                    )
+                    if not parent_has_dir:
+                        return f"Error: directory '{validated}' not found."
+
+                if children:
+                    return (
+                        f"Error: directory '{validated}' is not empty. "
+                        "Remove contents first."
+                    )
+
+                # Same-turn mkdir un-stage: drop the staged_dirs entry
+                # entirely and skip queuing a DB delete (nothing was ever
+                # committed).
+                if exists_in_staged:
+                    rest = [d for d in staged_dirs if d != validated]
+                    return Command(
+                        update={
+                            "staged_dirs": [_CLEAR, *rest],
+                            "staged_dir_tool_calls": {validated: None},
+                            "messages": [
+                                ToolMessage(
+                                    content=f"Un-staged directory '{validated}'.",
+                                    tool_call_id=runtime.tool_call_id,
+                                )
+                            ],
+                        }
+                    )
+
+                return Command(
+                    update={
+                        "pending_dir_deletes": [
+                            {
+                                "path": validated,
+                                "tool_call_id": runtime.tool_call_id,
+                            }
+                        ],
+                        "messages": [
+                            ToolMessage(
+                                content=(
+                                    f"Staged rmdir of '{validated}' (will commit "
+                                    "at end of turn)."
+                                ),
+                                tool_call_id=runtime.tool_call_id,
+                            )
+                        ],
+                    }
+                )
+
+            # Desktop mode — hit disk immediately.
+            backend = self._get_backend(runtime)
+            armdir = getattr(backend, "armdir", None)
+            if not callable(armdir):
+                return "Error: rmdir is not supported by the active backend."
+            res: WriteResult = await armdir(validated)
+            if res.error:
+                return res.error
+            return Command(
+                update={
+                    "messages": [
+                        ToolMessage(
+                            content=f"Deleted directory '{res.path or validated}'",
+                            tool_call_id=runtime.tool_call_id,
+                        )
+                    ],
+                }
+            )
+
+        def sync_rmdir(
+            path: Annotated[
+                str,
+                "Absolute or relative path of the empty directory to delete.",
+            ],
+            runtime: ToolRuntime[None, SurfSenseFilesystemState],
+        ) -> Command | str:
+            return self._run_async_blocking(async_rmdir(path, runtime))
+
+        return StructuredTool.from_function(
+            name="rmdir",
+            description=tool_description,
+            func=sync_rmdir,
+            coroutine=async_rmdir,
+        )
+
     # ------------------------------------------------------------------ tool: list_tree

     def _create_list_tree_tool(self) -> BaseTool:
File diff suppressed because it is too large
@@ -115,6 +115,12 @@ class KBPostgresBackend(BackendProtocol):
     def _pending_moves(self) -> list[dict[str, Any]]:
         return list(self.state.get("pending_moves") or [])

+    def _pending_deletes(self) -> list[dict[str, Any]]:
+        return list(self.state.get("pending_deletes") or [])
+
+    def _pending_dir_deletes(self) -> list[dict[str, Any]]:
+        return list(self.state.get("pending_dir_deletes") or [])
+
     def _kb_anon_doc(self) -> dict[str, Any] | None:
         anon = self.state.get("kb_anon_doc")
         return anon if isinstance(anon, dict) else None

@@ -140,18 +146,28 @@ class KBPostgresBackend(BackendProtocol):
             return path
         return path.rstrip("/") if path != "/" else path

-    def _moved_view_paths(
+    def _pending_filesystem_view(
         self,
         existing: dict[str, dict[str, Any]],
-    ) -> tuple[set[str], dict[str, str]]:
-        """Apply ``pending_moves`` to a path set and return ``(removed, alias)``.
+    ) -> tuple[set[str], dict[str, str], set[str]]:
+        """Compute removed/aliased/dir-suppressed paths from staged ops.

-        Removed paths should disappear from listings; ``alias[source] = dest``
-        means a virtual entry should appear at ``dest`` even if no DB row is
-        yet there.
+        Returns ``(removed, alias, deleted_dirs)`` where:
+
+        * ``removed`` — paths to drop from listings (sources of pending moves
+          AND paths queued for ``rm``).
+        * ``alias`` — ``{source: dest}`` for pending moves; the dest should
+          appear as a virtual entry even when no DB row is at that path yet.
+        * ``deleted_dirs`` — folder paths queued for ``rmdir``; their entire
+          subtree (descendants) is suppressed from listings/glob/grep.
+
+        Entries in ``existing`` (the ``files`` state cache) keyed by a
+        removed path are popped so a same-turn delete-after-write doesn't
+        leave a stale virtual file in listings.
         """
         removed: set[str] = set()
         alias: dict[str, str] = {}
+        deleted_dirs: set[str] = set()
         for move in self._pending_moves():
             src = move.get("source")
             dst = move.get("dest")

@@ -160,7 +176,23 @@ class KBPostgresBackend(BackendProtocol):
             removed.add(src)
             alias[src] = dst
             existing.pop(src, None)
-        return removed, alias
+        for entry in self._pending_deletes():
+            path = entry.get("path") if isinstance(entry, dict) else None
+            if not path:
+                continue
+            removed.add(path)
+            existing.pop(path, None)
+        for entry in self._pending_dir_deletes():
+            path = entry.get("path") if isinstance(entry, dict) else None
+            if not path:
+                continue
+            deleted_dirs.add(path)
+        return removed, alias, deleted_dirs
+
+    @staticmethod
+    def _is_dir_suppressed(path: str, deleted_dirs: set[str]) -> bool:
+        """Return True iff ``path`` is at-or-under any directory in ``deleted_dirs``."""
+        return any(path == d or _is_under(path, d) for d in deleted_dirs)

     # ------------------------------------------------------------------ ls/read
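Spot checks of the subtree suppression above, assuming ``_is_under`` is a path-segment-aware "strictly under this directory" test (as its uses throughout this file suggest):

deleted_dirs = {"/documents/archive"}
assert KBPostgresBackend._is_dir_suppressed("/documents/archive", deleted_dirs)
assert KBPostgresBackend._is_dir_suppressed("/documents/archive/old.md", deleted_dirs)
assert not KBPostgresBackend._is_dir_suppressed("/documents/active.md", deleted_dirs)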
@@ -189,7 +221,7 @@
             seen.add(anon_path)

         files = self._state_files()
-        moved_removed, moved_alias = self._moved_view_paths(files)
+        moved_removed, moved_alias, deleted_dirs = self._pending_filesystem_view(files)

         if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
             try:

@@ -203,7 +235,12 @@

             for info in db_infos:
                 p = info.get("path", "")
-                if not p or p in seen or p in moved_removed:
+                if (
+                    not p
+                    or p in seen
+                    or p in moved_removed
+                    or self._is_dir_suppressed(p, deleted_dirs)
+                ):
                     continue
                 infos.append(info)
                 seen.add(p)

@@ -212,6 +249,8 @@
                 if src not in seen:
                     if not _is_under(dst, normalized):
                         continue
+                    if self._is_dir_suppressed(dst, deleted_dirs):
+                        continue
                     rel = (
                         dst[len(normalized) :].lstrip("/")
                         if normalized != "/"

@@ -247,6 +286,8 @@
                     continue
                 if not _is_under(staged, normalized):
                     continue
+                if self._is_dir_suppressed(staged, deleted_dirs):
+                    continue
                 rel = (
                     staged[len(normalized) :].lstrip("/")
                     if normalized != "/"

@@ -265,14 +306,26 @@
             for sub in sorted(subdir_paths):
                 if sub in seen:
                     continue
+                if self._is_dir_suppressed(sub, deleted_dirs):
+                    continue
                 infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at=""))
                 seen.add(sub)

             for path_key, fd in files.items():
                 if not isinstance(path_key, str) or path_key in seen:
                     continue
+                # Tombstones (None values) are deletion markers from `rm`. The
+                # deepagents reducer normally pops them, but a stale tombstone
+                # surviving a checkpoint must NOT be reported as a child here —
+                # otherwise rmdir mistakenly sees the deleted file as content.
+                if fd is None:
+                    continue
                 if not _is_under(path_key, normalized) or path_key == normalized:
                     continue
-                if path_key in moved_removed:
+                if path_key in moved_removed or self._is_dir_suppressed(
+                    path_key, deleted_dirs
+                ):
                     continue
                 if normalized == "/":
                     rel = path_key.lstrip("/")
                 else:

@@ -550,10 +603,12 @@
         seen: set[str] = set()

         files = self._state_files()
-        moved_removed, _ = self._moved_view_paths(files)
+        moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
         regex = re.compile(fnmatch.translate(pattern))
         for path_key, fd in files.items():
-            if path_key in moved_removed:
+            if path_key in moved_removed or self._is_dir_suppressed(
+                path_key, deleted_dirs
+            ):
                 continue
             if not _is_under(path_key, normalized):
                 continue

@@ -595,7 +650,11 @@
                 folder_id=row.folder_id,
                 index=index,
             )
-            if candidate in seen or candidate in moved_removed:
+            if (
+                candidate in seen
+                or candidate in moved_removed
+                or self._is_dir_suppressed(candidate, deleted_dirs)
+            ):
                 continue
             if not _is_under(candidate, normalized):
                 continue

@@ -634,10 +693,12 @@
         matches: list[GrepMatch] = []

         files = self._state_files()
-        moved_removed, _ = self._moved_view_paths(files)
+        moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
         glob_re = re.compile(fnmatch.translate(glob)) if glob else None
         for path_key, fd in files.items():
-            if path_key in moved_removed:
+            if path_key in moved_removed or self._is_dir_suppressed(
+                path_key, deleted_dirs
+            ):
                 continue
             if not _is_under(path_key, normalized):
                 continue

@@ -695,7 +756,11 @@
             )
             for doc_id, chunk_id, content in chunk_buffer:
                 candidate = doc_id_to_path.get(doc_id)
-                if not candidate or candidate in moved_removed:
+                if (
+                    not candidate
+                    or candidate in moved_removed
+                    or self._is_dir_suppressed(candidate, deleted_dirs)
+                ):
                     continue
                 if not _is_under(candidate, normalized):
                     continue

@@ -769,7 +834,7 @@
             return {"entries": [], "truncated": False}

         files = self._state_files()
-        moved_removed, _ = self._moved_view_paths(files)
+        moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
         anon = self._kb_anon_doc()
         anon_path = str(anon.get("path") or "") if anon else ""

@@ -795,6 +860,8 @@
         for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]):
             if not _is_under(fpath, normalized):
                 continue
+            if self._is_dir_suppressed(fpath, deleted_dirs):
+                continue
             depth = _depth_of(fpath)
             if max_depth is not None and depth > max_depth:
                 continue

@@ -811,6 +878,8 @@
         for staged in self._staged_dirs():
             if not _is_under(staged, normalized):
                 continue
+            if self._is_dir_suppressed(staged, deleted_dirs):
+                continue
             depth = _depth_of(staged)
             if max_depth is not None and depth > max_depth:
                 continue

@@ -835,7 +904,9 @@
                 folder_id=row.folder_id,
                 index=index,
             )
-            if candidate in moved_removed:
+            if candidate in moved_removed or self._is_dir_suppressed(
+                candidate, deleted_dirs
+            ):
                 continue
             if not _is_under(candidate, normalized):
                 continue

@@ -875,6 +946,10 @@
                 continue
             if not _is_under(path_key, normalized):
                 continue
+            if path_key in moved_removed or self._is_dir_suppressed(
+                path_key, deleted_dirs
+            ):
+                continue
             if any(e["path"] == path_key for e in entries):
                 continue
             if not (
@@ -201,6 +201,12 @@ class KnowledgeTreeMiddleware(AgentMiddleware):  # type: ignore[type-arg]
         )
         all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))

+        # Pre-compute which folders have at least one descendant (folder or doc).
+        # A folder is "empty" iff no path in `all_paths` is strictly under it.
+        # Used to emit an explicit "(empty)" marker so the LLM doesn't have to
+        # infer emptiness from indentation alone.
+        non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)
+
         lines: list[str] = []
         for path in all_paths:
             depth = (

@@ -214,7 +220,10 @@ class KnowledgeTreeMiddleware(AgentMiddleware):  # type: ignore[type-arg]
                 path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
             )
             if is_dir:
-                lines.append(f"{indent}{display}/")
+                if path != DOCUMENTS_ROOT and path not in non_empty_folders:
+                    lines.append(f"{indent}{display}/ (empty)")
+                else:
+                    lines.append(f"{indent}{display}/")
             else:
                 lines.append(f"{indent}{display}")
             if len(lines) >= self.max_entries:

@@ -235,6 +244,35 @@ class KnowledgeTreeMiddleware(AgentMiddleware):  # type: ignore[type-arg]

         return self._format_root_summary(folder_paths, doc_paths)

+    @staticmethod
+    def _compute_non_empty_folders(
+        folder_paths: list[str], doc_paths: list[str]
+    ) -> set[str]:
+        """Return the set of folder paths that contain at least one descendant.
+
+        A folder is "non-empty" if any document path or any other folder path
+        is strictly under it. Documents propagate non-emptiness up to every
+        ancestor folder, while a sub-folder marks only the contiguous chain
+        of known ancestor folders above it (so only leaf folders with no
+        contents read ``(empty)``).
+        """
+        non_empty: set[str] = set()
+        folder_set = set(folder_paths)
+
+        for doc_path in doc_paths:
+            parent = doc_path.rsplit("/", 1)[0]
+            while parent and parent != DOCUMENTS_ROOT:
+                if parent in folder_set:
+                    non_empty.add(parent)
+                parent = parent.rsplit("/", 1)[0]
+
+        for child in folder_paths:
+            parent = child.rsplit("/", 1)[0]
+            while parent and parent != DOCUMENTS_ROOT and parent in folder_set:
+                non_empty.add(parent)
+                parent = parent.rsplit("/", 1)[0]
+
+        return non_empty
+
     def _format_root_summary(
         self, folder_paths: list[str], doc_paths: list[str]
     ) -> str:
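A worked example of those emptiness rules (assuming ``DOCUMENTS_ROOT == "/documents"``):

folders = ["/documents/a", "/documents/a/b", "/documents/c"]
docs = ["/documents/a/b/note.md"]

# note.md marks both /documents/a/b and /documents/a non-empty;
# /documents/c has no descendants, so the tree renders it as "c/ (empty)".
assert KnowledgeTreeMiddleware._compute_non_empty_folders(folders, docs) == {
    "/documents/a",
    "/documents/a/b",
}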
@@ -360,6 +360,74 @@ class LocalFolderBackend:
             self.move, source_path, destination_path, overwrite
         )

+    def delete_file(self, file_path: str) -> WriteResult:
+        """Hard-delete a single file under root.
+
+        Refuses directories, root, and missing paths. Roughly mirrors POSIX
+        ``rm path``; ``-r`` recursion and glob expansion are explicitly
+        out of scope.
+        """
+        try:
+            path = self._resolve_virtual(file_path)
+        except ValueError:
+            return WriteResult(error=f"Error: Invalid path '{file_path}'")
+        with self._lock_for(file_path):
+            if not path.exists():
+                return WriteResult(error=f"Error: File '{file_path}' not found")
+            if path.is_dir():
+                return WriteResult(
+                    error=(
+                        f"Error: '{file_path}' is a directory. "
+                        "Use rmdir for empty directories."
+                    )
+                )
+            try:
+                os.unlink(path)
+            except OSError as exc:
+                return WriteResult(
+                    error=f"Error: failed to delete '{file_path}': {exc}"
+                )
+        return WriteResult(path=file_path, files_update=None)
+
+    async def adelete_file(self, file_path: str) -> WriteResult:
+        return await asyncio.to_thread(self.delete_file, file_path)
+
+    def rmdir(self, dir_path: str) -> WriteResult:
+        """Hard-delete an empty directory under root.
+
+        Refuses files, root, missing paths, and non-empty directories.
+        ``os.rmdir`` is naturally empty-only; we pre-check so the error is
+        clearer for the agent.
+        """
+        try:
+            path = self._resolve_virtual(dir_path)
+        except ValueError:
+            return WriteResult(error=f"Error: Invalid path '{dir_path}'")
+        with self._lock_for(dir_path):
+            if not path.exists():
+                return WriteResult(error=f"Error: Directory '{dir_path}' not found")
+            if not path.is_dir():
+                return WriteResult(error=f"Error: '{dir_path}' is not a directory")
+            try:
+                next(path.iterdir())
+            except StopIteration:
+                pass
+            else:
+                return WriteResult(
+                    error=(
+                        f"Error: directory '{dir_path}' is not empty. "
+                        "Remove its contents first."
+                    )
+                )
+            try:
+                os.rmdir(path)
+            except OSError as exc:
+                return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}")
+        return WriteResult(path=dir_path, files_update=None)
+
+    async def armdir(self, dir_path: str) -> WriteResult:
+        return await asyncio.to_thread(self.rmdir, dir_path)
+
     def edit(
         self,
         file_path: str,
@@ -285,6 +285,34 @@ class MultiRootLocalFolderBackend:
             overwrite,
         )

+    def delete_file(self, file_path: str) -> WriteResult:
+        try:
+            mount, local_path = self._split_mount_path(file_path)
+        except ValueError as exc:
+            return WriteResult(error=f"Error: {exc}")
+        result = self._mount_to_backend[mount].delete_file(local_path)
+        if result.path:
+            result.path = self._prefix_mount_path(mount, result.path)
+        return result
+
+    async def adelete_file(self, file_path: str) -> WriteResult:
+        return await asyncio.to_thread(self.delete_file, file_path)
+
+    def rmdir(self, dir_path: str) -> WriteResult:
+        try:
+            mount, local_path = self._split_mount_path(dir_path)
+        except ValueError as exc:
+            return WriteResult(error=f"Error: {exc}")
+        if local_path == "/":
+            return WriteResult(error=f"Error: cannot rmdir mount root '{dir_path}'")
+        result = self._mount_to_backend[mount].rmdir(local_path)
+        if result.path:
+            result.path = self._prefix_mount_path(mount, result.path)
+        return result
+
+    async def armdir(self, dir_path: str) -> WriteResult:
+        return await asyncio.to_thread(self.rmdir, dir_path)
+
     def edit(
         self,
         file_path: str,
@@ -181,9 +181,13 @@ def _initial_filesystem_state() -> dict[str, Any]:
     return {
         "cwd": "/documents",
         "staged_dirs": [],
+        "staged_dir_tool_calls": {},
         "pending_moves": [],
+        "pending_deletes": [],
+        "pending_dir_deletes": [],
         "doc_id_by_path": {},
         "dirty_paths": [],
+        "dirty_path_tool_calls": {},
         "kb_priority": [],
         "kb_matched_chunk_ids": {},
         "kb_anon_doc": None,
@@ -90,6 +90,8 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = (
     "write_file",
     "move_file",
     "mkdir",
+    "rm",
+    "rmdir",
     "update_memory",
     "update_memory_team",
     "update_memory_private",
@@ -30,6 +30,35 @@ from langgraph.types import interrupt
 logger = logging.getLogger(__name__)


+# Tools that mirror the safety profile of ``write_file`` against the
+# SurfSense KB: each call creates ONE artifact in the user's own workspace
+# with no external visibility (drafts aren't sent; new files aren't shared
+# unless the user shares them later). These are auto-approved by default
+# so the agent can compose drafts and seed scratch files without a popup
+# on every call.
+#
+# Members of this set still call ``request_approval`` exactly as before;
+# the function returns immediately with ``decision_type="auto_approved"``
+# and the original params untouched. This preserves the call-site shape
+# (logging, metadata fetching, account fallbacks) so the only behavior
+# change is "no interrupt fires".
+#
+# To re-enable prompting, the future per-search-space rules table
+# (``agent_permission_rules``) takes precedence — see the ``# (future)``
+# layer-4 comment in :mod:`app.agents.new_chat.chat_deepagent`.
+DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
+    {
+        "create_gmail_draft",
+        "update_gmail_draft",
+        "create_notion_page",
+        "create_confluence_page",
+        "create_google_drive_file",
+        "create_dropbox_file",
+        "create_onedrive_file",
+    }
+)
+
+
 @dataclass(frozen=True, slots=True)
 class HITLResult:
     """Outcome of a human-in-the-loop approval request."""

@@ -119,6 +148,19 @@ def request_approval(
         logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
         return HITLResult(rejected=False, decision_type="trusted", params=dict(params))

+    if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
+        # Default policy: low-stakes creation tools (drafts + new-file
+        # creates) skip HITL because they're as recoverable as a local
+        # ``write_file`` against the SurfSense KB. The user can still
+        # delete the artifact in <30s if it's wrong.
+        logger.info(
+            "Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
+            tool_name,
+        )
+        return HITLResult(
+            rejected=False, decision_type="auto_approved", params=dict(params)
+        )
+
     approval = interrupt(
         {
             "type": action_type,
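A sketch of the call-site shape this preserves. Only ``tool_name`` and ``params`` are visible as keyword names in this excerpt, and ``action_type`` appears in the interrupt payload above; the literal values here are illustrative, not taken from the code:

result = request_approval(
    tool_name="create_gmail_draft",  # member of DEFAULT_AUTO_APPROVED_TOOLS
    params={"to": "alice@example.com", "subject": "Q3 notes"},
    action_type="connector_write",   # illustrative action type
)
if result.rejected:
    # Only reachable for tools that actually interrupted.
    raise RuntimeError("user rejected the action")
# For auto-approved tools, no interrupt fires and params round-trip unchanged.
assert result.decision_type == "auto_approved"
send_params = result.params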
@ -689,6 +689,12 @@ class NewChatMessage(BaseModel, TimestampMixin):
|
|||
index=True,
|
||||
)
|
||||
|
||||
# Per-turn correlation id sourced from ``configurable.turn_id`` at
|
||||
# streaming time (``f"{chat_id}:{ms}"``). Nullable because legacy rows
|
||||
# predate the column. Used by C1's edit-from-arbitrary-position to map
|
||||
# a message back to the LangGraph checkpoint that produced its turn.
|
||||
turn_id = Column(String(64), nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
thread = relationship("NewChatThread", back_populates="messages")
|
||||
author = relationship("User")

@@ -2292,7 +2298,13 @@ class AgentActionLog(BaseModel):
        nullable=False,
        index=True,
    )
    # ``turn_id`` historically held the LangChain ``tool_call.id``. It has
    # been renamed to ``tool_call_id`` (with a parallel column kept for one
    # release for back-compat). The real chat-turn id lives in
    # ``chat_turn_id`` and is sourced from ``configurable.turn_id``.
    turn_id = Column(String(64), nullable=True, index=True)
    tool_call_id = Column(String(64), nullable=True, index=True)
    chat_turn_id = Column(String(64), nullable=True, index=True)
    message_id = Column(String(128), nullable=True, index=True)
    tool_name = Column(String(255), nullable=False, index=True)
    args = Column(JSONB, nullable=True)

@@ -2318,6 +2330,16 @@ class AgentActionLog(BaseModel):

    __table_args__ = (
        Index("ix_agent_action_log_thread_created", "thread_id", "created_at"),
        # Partial unique index enforces "at most one revert per
        # original action". Created in migration 137 with
        # ``WHERE reverse_of IS NOT NULL`` so non-revert rows
        # (the vast majority) are unaffected and NULLs don't collide.
        Index(
            "ux_agent_action_log_reverse_of",
            "reverse_of",
            unique=True,
            postgresql_where=text("reverse_of IS NOT NULL"),
        ),
    )
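For reference, a minimal Alembic sketch of the partial unique index described above (the real migration 137 may differ in naming and surroundings):

    import sqlalchemy as sa
    from alembic import op

    def upgrade() -> None:
        # NULLs never collide under the WHERE clause, so only revert rows
        # (reverse_of IS NOT NULL) pay for uniqueness enforcement.
        op.create_index(
            "ux_agent_action_log_reverse_of",
            "agent_action_log",
            ["reverse_of"],
            unique=True,
            postgresql_where=sa.text("reverse_of IS NOT NULL"),
        )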

@@ -2332,10 +2354,13 @@ class DocumentRevision(BaseModel):

    __tablename__ = "document_revisions"

    # ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
    # hard-delete it describes — without that, ``rm`` would wipe the row
    # we'd need in order to undo it. See migration ``134_relax_revision_fks``.
    document_id = Column(
        Integer,
        ForeignKey("documents.id", ondelete="CASCADE"),
        nullable=False,
        ForeignKey("documents.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )
    search_space_id = Column(

@@ -2370,10 +2395,13 @@ class FolderRevision(BaseModel):

    __tablename__ = "folder_revisions"

    # ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
    # hard-delete it describes — without that, ``rmdir`` would wipe the
    # row we'd need in order to undo it. See migration ``134_relax_revision_fks``.
    folder_id = Column(
        Integer,
        ForeignKey("folders.id", ondelete="CASCADE"),
        nullable=False,
        ForeignKey("folders.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )
    search_space_id = Column(

@@ -65,6 +65,13 @@ class AgentActionRead(BaseModel):
    reverse_of: int | None
    reverted_by_action_id: int | None
    is_revert_action: bool
    # Correlation ids added in migration 135. ``tool_call_id`` is the
    # LangChain tool-call id (joinable to ``data-action-log`` SSE events
    # via ``langchainToolCallId``). ``chat_turn_id`` is the per-turn id
    # from ``configurable.turn_id`` (used by the
    # ``revert-turn/{chat_turn_id}`` endpoint).
    tool_call_id: str | None = None
    chat_turn_id: str | None = None
    created_at: datetime


@@ -172,6 +179,8 @@ async def list_thread_actions(
            reverse_of=row.reverse_of,
            reverted_by_action_id=revert_map.get(row.id),
            is_revert_action=row.reverse_of is not None,
            tool_call_id=row.tool_call_id,
            chat_turn_id=row.chat_turn_id,
            created_at=row.created_at,
        )
        for row in rows
@@ -11,14 +11,25 @@ flag flips. Once enabled, the route runs:
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
5. Idempotent on retries: if the same action is reverted twice the second
   call returns 409 ``"already reverted"``.

This module also hosts the per-turn batch endpoint
``POST /api/threads/{thread_id}/revert-turn/{chat_turn_id}``. It
walks every reversible action emitted during a chat turn in reverse
``created_at`` order and reverts each independently. Partial success is the
common case — the response always contains a per-action result list and a
``status`` of ``"ok"`` or ``"partial"``; we never collapse the batch into a
whole-batch 4xx.
"""

from __future__ import annotations

import logging
from typing import Literal

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.feature_flags import get_flags

@@ -97,6 +108,16 @@ async def revert_agent_action(
            action=action,
            requester_user_id=str(user.id) if user is not None else None,
        )
    except IntegrityError:
        # Partial unique index ``ux_agent_action_log_reverse_of`` caught
        # a concurrent revert. Translate to the existing 409 "already
        # reverted" contract so racing clients see consistent
        # behaviour with the pre-flight TOCTOU check above.
        await session.rollback()
        raise HTTPException(
            status_code=409,
            detail="This action has already been reverted.",
        ) from None
    except Exception as err:
        logger.exception("Revert dispatch raised for action_id=%s", action_id)
        await session.rollback()

@@ -105,7 +126,16 @@ async def revert_agent_action(
        ) from err

    if outcome.status == "ok":
        await session.commit()
        try:
            await session.commit()
        except IntegrityError:
            # Race lost on commit (constraint enforced at flush in some
            # configs but at commit in others — defensive).
            await session.rollback()
            raise HTTPException(
                status_code=409,
                detail="This action has already been reverted.",
            ) from None
        return {
            "status": "ok",
            "message": outcome.message,

@@ -122,3 +152,357 @@ async def revert_agent_action(
        raise HTTPException(status_code=501, detail=outcome.message)
    # not_reversible
    raise HTTPException(status_code=409, detail=outcome.message)


# ---------------------------------------------------------------------------
# Per-turn revert batch endpoint
# ---------------------------------------------------------------------------


PerActionStatus = Literal[
    "reverted",
    "already_reverted",
    "not_reversible",
    "permission_denied",
    "failed",
    "skipped",
]


class RevertTurnActionResult(BaseModel):
    """Per-action outcome inside a ``revert-turn`` batch response."""

    action_id: int
    tool_name: str
    status: PerActionStatus
    message: str | None = None
    new_action_id: int | None = None
    error: str | None = None


class RevertTurnResponse(BaseModel):
    """Top-level response for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.

    ``status`` is ``"ok"`` only when every reversible row succeeded. Any
    ``failed`` / ``not_reversible`` / ``permission_denied`` entry downgrades
    it to ``"partial"``. Empty turns (no rows) return ``"ok"`` with an empty
    ``results`` list — callers should treat that as a no-op.

    Counter invariant:
        ``total == reverted + already_reverted + not_reversible
        + permission_denied + failed + skipped``

    Frontend toasts and the ``RevertTurnButton`` summary rely on this
    invariant to display "X of Y reverted, Z could not be undone" without
    silently dropping ``permission_denied`` or ``skipped`` rows.
    """

    status: Literal["ok", "partial"]
    chat_turn_id: str
    total: int
    reverted: int
    already_reverted: int
    not_reversible: int
    permission_denied: int = 0
    failed: int = 0
    skipped: int = 0
    results: list[RevertTurnActionResult]
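An illustrative (made-up) partial response that satisfies the counter invariant:

    example = {
        "status": "partial",
        "chat_turn_id": "42:1719922345123",
        "total": 4,
        "reverted": 2,
        "already_reverted": 1,
        "not_reversible": 1,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }
    assert example["total"] == 2 + 1 + 1 + 0 + 0 + 0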

def _classify_outcome(outcome: RevertOutcome) -> PerActionStatus:
    if outcome.status == "ok":
        return "reverted"
    if outcome.status == "permission_denied":
        return "permission_denied"
    # ``not_found`` / ``tool_unavailable`` / ``reverse_not_implemented`` /
    # ``not_reversible`` are all surfaced to the caller as "not_reversible"
    # — they share the same UX (this row cannot be undone) and only the
    # ``message`` differs.
    return "not_reversible"


async def _was_already_reverted(session: AsyncSession, *, action_id: int) -> int | None:
    """Return the id of an existing successful revert row, if any.

    Single-action variant — kept for the post-IntegrityError lookup
    path where we already know we lost a race for one specific id.
    """
    stmt = select(AgentActionLog.id).where(AgentActionLog.reverse_of == action_id)
    result = await session.execute(stmt)
    return result.scalars().first()


async def _was_already_reverted_batch(
    session: AsyncSession, *, action_ids: list[int]
) -> dict[int, int]:
    """Batch idempotency probe for the revert-turn loop.

    Replaces N individual ``SELECT id WHERE reverse_of = :id`` queries
    (one per row in the turn) with a single ``SELECT id, reverse_of
    WHERE reverse_of IN (:ids)``. The route still iterates rows in
    reverse-chronological order, but the membership check is O(1) per
    iteration after this query. For a turn with 30 actions that's 30
    fewer round-trips through asyncpg + a smaller transaction footprint.

    Returns a ``{original_action_id -> revert_action_id}`` map. Missing
    keys mean "not yet reverted" — callers should treat them as
    eligible for revert.
    """
    if not action_ids:
        return {}
    stmt = select(AgentActionLog.id, AgentActionLog.reverse_of).where(
        AgentActionLog.reverse_of.in_(action_ids)
    )
    result = await session.execute(stmt)
    return {
        original_id: revert_id
        for revert_id, original_id in result.all()
        if original_id is not None
    }
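Illustrative use of the probe (ids made up):

    already = await _was_already_reverted_batch(session, action_ids=[11, 12, 13])
    # e.g. {12: 57}: action 12 was already reverted by action 57; 11 and 13
    # are absent from the map, so they are still eligible for revert.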

@router.post(
    "/threads/{thread_id}/revert-turn/{chat_turn_id}",
    response_model=RevertTurnResponse,
)
async def revert_agent_turn(
    thread_id: int,
    chat_turn_id: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
) -> RevertTurnResponse:
    """Revert every reversible action emitted during ``chat_turn_id``.

    Walks ``AgentActionLog`` rows for the turn in reverse ``created_at``
    order so dependencies (e.g. ``mkdir`` -> ``write_file`` inside the new
    folder) unwind in the right sequence. Each action is reverted in its
    own SAVEPOINT so a single failure does not poison the batch.

    Partial success is intentional and returned with HTTP 200. Callers
    must inspect ``results[*].status`` to find rows that need attention.
    """

    flags = get_flags()
    if flags.disable_new_agent_stack or not flags.enable_revert_route:
        raise HTTPException(
            status_code=503,
            detail=(
                "Revert is not available on this deployment yet. The route "
                "ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to "
                "enable it."
            ),
        )

    thread = await load_thread(session, thread_id=thread_id)
    if thread is None:
        raise HTTPException(status_code=404, detail="Thread not found.")

    # Reverse-chronological so the latest mutation in the turn unwinds
    # first. ``id.desc()`` is the deterministic tiebreaker for actions
    # written in the same millisecond.
    rows_stmt = (
        select(AgentActionLog)
        .where(
            AgentActionLog.thread_id == thread_id,
            AgentActionLog.chat_turn_id == chat_turn_id,
        )
        .order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc())
    )
    rows = (await session.execute(rows_stmt)).scalars().all()

    requester_user_id = str(user.id) if user is not None else None
    results: list[RevertTurnActionResult] = []
    # Counters MUST be exhaustive so the response invariant
    # ``total == sum(counters)`` always holds. Frontend toasts and
    # ``RevertTurnButton`` rely on this for "X of Y reverted" math.
    counts: dict[str, int] = {
        "reverted": 0,
        "already_reverted": 0,
        "not_reversible": 0,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }

    # Single batched idempotency probe replaces the previous per-row
    # SELECT. ``rows`` are filtered in the loop so we pre-collect only
    # the original-action ids (skip rows that are themselves reverts).
    eligible_ids = [r.id for r in rows if r.reverse_of is None]
    already_reverted_map = await _was_already_reverted_batch(
        session, action_ids=eligible_ids
    )

    for action in rows:
        # Skip rows that ARE reverts of an earlier action — reverting a
        # revert is meaningless inside a batch (the user wants to wipe
        # the original effects, not chase the tail of the revert chain).
        if action.reverse_of is not None:
            counts["skipped"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="skipped",
                    message="Row is itself a revert action; skipped.",
                )
            )
            continue

        # Idempotency: surface "already_reverted" instead of failing.
        existing_revert_id = already_reverted_map.get(action.id)
        if existing_revert_id is not None:
            counts["already_reverted"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="already_reverted",
                    new_action_id=existing_revert_id,
                )
            )
            continue

        if not can_revert(
            requester_user_id=requester_user_id,
            action=action,
            is_admin=False,
        ):
            counts["permission_denied"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="permission_denied",
                    message="You are not allowed to revert this action.",
                )
            )
            continue

        # Per-row SAVEPOINT so one failed revert never poisons later
        # successful ones.
        try:
            async with session.begin_nested():
                outcome = await revert_action(
                    session,
                    action=action,
                    requester_user_id=requester_user_id,
                )
                if outcome.status != "ok":
                    raise _OutcomeRollbackError(outcome)
        except _OutcomeRollbackError as rollback:
            outcome = rollback.outcome
            classified = _classify_outcome(outcome)
            if classified == "permission_denied":
                counts["permission_denied"] += 1
            else:
                counts["not_reversible"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status=classified,
                    message=outcome.message,
                )
            )
            continue
        except IntegrityError:
            # Partial unique index caught a concurrent revert that won
            # the race against our pre-flight ``_was_already_reverted``
            # SELECT. Look up the winner so we can surface its
            # ``new_action_id`` to the client.
            existing_revert_id = await _was_already_reverted(
                session, action_id=action.id
            )
            counts["already_reverted"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="already_reverted",
                    new_action_id=existing_revert_id,
                )
            )
            continue
        except Exception as err:  # pragma: no cover — defensive, logged
            logger.exception(
                "Unexpected revert failure inside batch for action_id=%s",
                action.id,
            )
            counts["failed"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="failed",
                    error=str(err) or err.__class__.__name__,
                )
            )
            continue

        counts["reverted"] += 1
        results.append(
            RevertTurnActionResult(
                action_id=action.id,
                tool_name=action.tool_name,
                status="reverted",
                message=outcome.message,
                new_action_id=outcome.new_action_id,
            )
        )

    # Single commit at the end — successful SAVEPOINTs above already
    # released; failed ones rolled back to their savepoint. No row leaks
    # across the boundary.
    try:
        await session.commit()
    except Exception as err:  # pragma: no cover — defensive
        logger.exception(
            "Final commit for revert-turn failed (thread=%s turn=%s)",
            thread_id,
            chat_turn_id,
        )
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail="Internal error while finalising revert-turn batch.",
        ) from err

    has_partial = (
        counts["failed"] > 0
        or counts["not_reversible"] > 0
        or counts["permission_denied"] > 0
    )
    overall_status: Literal["ok", "partial"] = "partial" if has_partial else "ok"

    return RevertTurnResponse(
        status=overall_status,
        chat_turn_id=chat_turn_id,
        total=len(rows),
        reverted=counts["reverted"],
        already_reverted=counts["already_reverted"],
        not_reversible=counts["not_reversible"],
        permission_denied=counts["permission_denied"],
        failed=counts["failed"],
        skipped=counts["skipped"],
        results=results,
    )


class _OutcomeRollbackError(Exception):
    """Sentinel raised inside the SAVEPOINT to roll back a non-OK outcome.

    ``revert_action`` writes a new ``agent_action_log`` row only on the
    happy path, but on the failure paths it sometimes mutates the
    ``DocumentRevision``/``Document`` tables before deciding the action
    is not reversible. Wrapping each call in ``begin_nested`` and raising
    this from the failure branch ensures we always discard partial
    writes for failed rows.
    """

    def __init__(self, outcome: RevertOutcome) -> None:
        self.outcome = outcome
        super().__init__(outcome.message)


__all__ = ["router"]
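The SAVEPOINT-plus-sentinel pattern used by the loop above, reduced to a standalone sketch (``work`` is a placeholder coroutine):

    async def attempt(session, work):
        # begin_nested() opens a SAVEPOINT; raising inside the block rolls
        # back to it without disturbing earlier, already-released savepoints.
        try:
            async with session.begin_nested():
                outcome = await work(session)
                if outcome.status != "ok":
                    raise _OutcomeRollbackError(outcome)
        except _OutcomeRollbackError as sentinel:
            return sentinel.outcome  # non-OK outcome; partial writes discarded
        return outcome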

@@ -11,6 +11,7 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
"""

import asyncio
import json
import logging
from datetime import UTC, datetime

@@ -136,6 +137,260 @@ def _resolve_filesystem_selection(
    )


def _find_pre_turn_checkpoint_id(
    checkpoint_tuples: list,
    *,
    turn_id: str,
) -> str | None:
    """Locate the LangGraph checkpoint immediately before ``turn_id`` started.

    ``checkpoint_tuples`` arrives newest-first from
    ``checkpointer.alist(config)``. We walk OLDEST-first (``reversed``)
    and remember the most recent checkpoint that does NOT belong to the
    edited turn. As soon as we cross into the edited turn (a checkpoint
    whose ``turn_id`` matches), we return the previously-tracked
    checkpoint — that's the state immediately before ``turn_id`` began.

    The naive "newest-first, return first non-matching" approach is
    INCORRECT when later turns exist after ``turn_id``: their
    checkpoints also satisfy ``cp_turn_id != turn_id`` and would be
    returned before the real pre-turn boundary is reached.

    Reads from ``cp_tuple.metadata`` (the durable surface promoted from
    ``configurable`` at write time) rather than ``config["configurable"]``
    so the lookup is portable across checkpointer implementations.

    Returns ``None`` when no eligible pre-turn checkpoint exists (e.g.
    the edited turn is the very first turn of the thread). Callers fall
    back to the oldest available checkpoint in that case.
    """

    last_pre_turn_target: str | None = None
    for cp_tuple in reversed(checkpoint_tuples):  # oldest -> newest
        metadata = getattr(cp_tuple, "metadata", None) or {}
        cp_turn_id = metadata.get("turn_id") if isinstance(metadata, dict) else None
        if cp_turn_id == turn_id:
            # Crossed into the edited turn; the previously-tracked
            # checkpoint is the rewind target. May be ``None`` if we hit
            # the edited turn on the very first iteration.
            return last_pre_turn_target
        try:
            last_pre_turn_target = cp_tuple.config["configurable"]["checkpoint_id"]
        except (KeyError, TypeError):
            continue
    return last_pre_turn_target
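A tiny illustration of the walk above, with ``SimpleNamespace`` standing in for real checkpoint tuples:

    from types import SimpleNamespace as NS

    def _cp(cp_id: str, turn: str) -> NS:
        return NS(
            metadata={"turn_id": turn},
            config={"configurable": {"checkpoint_id": cp_id}},
        )

    # Newest-first, as checkpointer.alist() yields them; turn "B" follows "A".
    tuples = [_cp("c4", "B"), _cp("c3", "B"), _cp("c2", "A"), _cp("c1", "A")]
    assert _find_pre_turn_checkpoint_id(tuples, turn_id="B") == "c2"
    assert _find_pre_turn_checkpoint_id(tuples, turn_id="A") is None  # first turn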

async def _revert_turns_for_regenerate(
    *,
    thread_id: int,
    chat_turn_ids: list[str],
    requester_user_id: str,
) -> dict:
    """Best-effort revert pass for every ``chat_turn_id`` in ``chat_turn_ids``.

    Runs BEFORE the regenerate stream so the frontend can surface
    partial-rollback feedback alongside the new assistant turn. Each
    turn's actions are reverted in their own SAVEPOINTs (handled
    inside :mod:`app.routes.agent_revert_route`'s helpers) so a single
    failure never poisons the batch.

    Sequencing inside the request: revert THEN regenerate. The
    operation is NOT atomic and partial state IS surfaced — see the
    plan's "Sequencing inside the request" note.
    """

    from app.routes.agent_revert_route import (
        RevertTurnActionResult,
        _classify_outcome,
        _OutcomeRollbackError,
        _was_already_reverted,
        _was_already_reverted_batch,
    )
    from app.services.revert_service import (
        can_revert,
        revert_action,
    )

    aggregated_results: list[dict] = []
    # Exhaustive counters keep the response invariant
    # ``total == sum(counters)`` true for ``data-revert-results``.
    counts = {
        "reverted": 0,
        "already_reverted": 0,
        "not_reversible": 0,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }

    # Local import keeps the route module's existing imports tidy and
    # avoids a circular dependency at module-load time.
    from app.db import AgentActionLog as _AgentActionLog

    async with shielded_async_session() as session:
        for chat_turn_id in chat_turn_ids:
            rows_stmt = (
                select(_AgentActionLog)
                .where(
                    _AgentActionLog.thread_id == thread_id,
                    _AgentActionLog.chat_turn_id == chat_turn_id,
                )
                .order_by(
                    _AgentActionLog.created_at.desc(),
                    _AgentActionLog.id.desc(),
                )
            )
            rows = (await session.execute(rows_stmt)).scalars().all()

            # Batch idempotency probe across the turn (single SELECT
            # instead of one per row).
            eligible_ids = [r.id for r in rows if r.reverse_of is None]
            already_reverted_map = await _was_already_reverted_batch(
                session, action_ids=eligible_ids
            )

            for action in rows:
                if action.reverse_of is not None:
                    counts["skipped"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="skipped",
                            message="Row is itself a revert action; skipped.",
                        ).model_dump()
                    )
                    continue

                existing_revert_id = already_reverted_map.get(action.id)
                if existing_revert_id is not None:
                    counts["already_reverted"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="already_reverted",
                            new_action_id=existing_revert_id,
                        ).model_dump()
                    )
                    continue

                if not can_revert(
                    requester_user_id=requester_user_id,
                    action=action,
                    is_admin=False,
                ):
                    counts["permission_denied"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="permission_denied",
                            message="You are not allowed to revert this action.",
                        ).model_dump()
                    )
                    continue

                try:
                    async with session.begin_nested():
                        outcome = await revert_action(
                            session,
                            action=action,
                            requester_user_id=requester_user_id,
                        )
                        if outcome.status != "ok":
                            raise _OutcomeRollbackError(outcome)
                except _OutcomeRollbackError as rollback:
                    outcome = rollback.outcome
                    classified = _classify_outcome(outcome)
                    if classified == "permission_denied":
                        counts["permission_denied"] += 1
                    else:
                        counts["not_reversible"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status=classified,
                            message=outcome.message,
                        ).model_dump()
                    )
                    continue
                except IntegrityError:
                    # Concurrent revert won the race against the
                    # pre-flight ``_was_already_reverted`` SELECT.
                    # Surface the winning revert id so the client can
                    # treat this as a successful idempotent op.
                    existing_revert_id = await _was_already_reverted(
                        session, action_id=action.id
                    )
                    counts["already_reverted"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="already_reverted",
                            new_action_id=existing_revert_id,
                        ).model_dump()
                    )
                    continue
                except Exception as err:  # pragma: no cover — defensive
                    _logger.exception(
                        "Unexpected revert failure during regenerate batch "
                        "for action_id=%s",
                        action.id,
                    )
                    counts["failed"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="failed",
                            error=str(err) or err.__class__.__name__,
                        ).model_dump()
                    )
                    continue

                counts["reverted"] += 1
                aggregated_results.append(
                    RevertTurnActionResult(
                        action_id=action.id,
                        tool_name=action.tool_name,
                        status="reverted",
                        message=outcome.message,
                        new_action_id=outcome.new_action_id,
                    ).model_dump()
                )

        try:
            await session.commit()
        except Exception:
            _logger.exception(
                "[regenerate-revert] Final commit failed; rolling back batch."
            )
            await session.rollback()

    has_partial = (
        counts["failed"] > 0
        or counts["not_reversible"] > 0
        or counts["permission_denied"] > 0
    )

    return {
        "status": "partial" if has_partial else "ok",
        "chat_turn_ids": chat_turn_ids,
        "total": len(aggregated_results),
        "reverted": counts["reverted"],
        "already_reverted": counts["already_reverted"],
        "not_reversible": counts["not_reversible"],
        "permission_denied": counts["permission_denied"],
        "failed": counts["failed"],
        "skipped": counts["skipped"],
        "results": aggregated_results,
    }


def _try_delete_sandbox(thread_id: int) -> None:
    """Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
    from app.agents.new_chat.sandbox import (
@@ -574,6 +829,7 @@ async def get_thread_messages(
            token_usage=TokenUsageSummary.model_validate(msg.token_usage)
            if msg.token_usage
            else None,
            turn_id=msg.turn_id,
        )
        for msg in db_messages
    ]

@@ -1006,12 +1262,24 @@ async def append_message(
    # Check thread-level access based on visibility
    await check_thread_access(session, thread, user)

    # Create message
    # Create message. ``turn_id`` is the per-turn correlation id from
    # ``configurable.turn_id`` (added in migration 136) — when the
    # client streams it back to ``appendMessage``, we persist it so
    # C1's edit-from-arbitrary-position can later map this message
    # back to the LangGraph checkpoint that produced its turn.
    raw_turn_id = raw_body.get("turn_id")
    turn_id_value = (
        str(raw_turn_id).strip()
        if isinstance(raw_turn_id, str) and raw_turn_id.strip()
        else None
    )

    db_message = NewChatMessage(
        thread_id=thread_id,
        role=message_role,
        content=content,
        author_id=user.id,
        turn_id=turn_id_value,
    )
    session.add(db_message)

@@ -1050,6 +1318,7 @@ async def append_message(
        created_at=db_message.created_at,
        author_id=db_message.author_id,
        token_usage=None,
        turn_id=db_message.turn_id,
    )

    except HTTPException:

@@ -1373,43 +1642,123 @@ async def regenerate_response(
    user_query_to_use = request.user_query
    regenerate_image_urls: list[str] = []

    # Look through checkpoints to find the right one
    # We want to find the checkpoint just before the last HumanMessage
    for i, cp_tuple in enumerate(checkpoint_tuples):
        # Access the checkpoint's channel_values which contains "messages"
        checkpoint_data = cp_tuple.checkpoint
        channel_values = checkpoint_data.get("channel_values", {})
        state_messages = channel_values.get("messages", [])
    # ---------------------------------------------------------------
    # Edit-from-arbitrary-position. When the client passes
    # ``from_message_id`` we look up its persisted ``turn_id`` (added
    # in migration 136) and pick the checkpoint immediately before
    # that turn started.
    #
    # Legacy graceful-degradation contract:
    # * Rows persisted BEFORE migration 136 have ``turn_id IS NULL``.
    #   Returning 400 in that case is the wrong UX — the user is
    #   editing an old message in an existing thread and just wants
    #   it to work. We instead skip the checkpoint rewind (the
    #   stream falls back to the latest state) and skip the revert
    #   pass (no chat_turn_id available to walk). Deletion still
    #   uses ``created_at``, so the messages-after-cursor slice is
    #   correct on both legacy and post-136 rows.
    # ---------------------------------------------------------------
    from_message_turn_id: str | None = None
    from_message_created_at: datetime | None = None
    legacy_from_message: bool = False
    if request.from_message_id is not None:
        from_msg_row = await session.execute(
            select(NewChatMessage).filter(
                NewChatMessage.id == request.from_message_id,
                NewChatMessage.thread_id == thread_id,
            )
        )
        from_msg = from_msg_row.scalars().first()
        if from_msg is None:
            raise HTTPException(
                status_code=404,
                detail="from_message_id not found in this thread.",
            )
        from_message_created_at = from_msg.created_at
        if not from_msg.turn_id:
            # Legacy row — surface the degradation in logs but let
            # the request proceed with the slice-based delete and a
            # cold-start checkpoint.
            legacy_from_message = True
            _logger.warning(
                "[regenerate] from_message_id=%s on thread=%s has no "
                "turn_id (legacy row pre-migration-136). Falling back "
                "to slice-based delete without checkpoint rewind. "
                "revert_actions=%s will be ignored.",
                request.from_message_id,
                thread_id,
                request.revert_actions,
            )
        else:
            from_message_turn_id = from_msg.turn_id

        if state_messages:
            last_msg = state_messages[-1]
            # Find a checkpoint where the last message is NOT a HumanMessage
            # This means we're at a state before the user's last message
            if not isinstance(last_msg, HumanMessage):
                # If no new user_query provided (reload), extract from a later checkpoint
                if user_query_to_use is None and i > 0:
                    # Get the user query from a more recent checkpoint
                    for prev_cp_tuple in checkpoint_tuples[:i]:
                        prev_checkpoint_data = prev_cp_tuple.checkpoint
                        prev_channel_values = prev_checkpoint_data.get(
                            "channel_values", {}
                        )
                        prev_messages = prev_channel_values.get("messages", [])
                        for msg in reversed(prev_messages):
                            if isinstance(msg, HumanMessage):
                                q, imgs = split_langchain_human_content(msg.content)
                                user_query_to_use = q
                                regenerate_image_urls = imgs
                                break
                        if user_query_to_use is not None and (
                            str(user_query_to_use).strip() or regenerate_image_urls
                        ):
                            break

                target_checkpoint_id = cp_tuple.config["configurable"][
        # Walk oldest-to-newest and pick the LAST checkpoint whose
        # ``turn_id`` differs from the edited turn — that's the state
        # immediately before this turn started running. We read from
        # ``metadata`` (the durable surface) rather than
        # ``config["configurable"]`` so the lookup works across
        # checkpointer implementations.
        target_checkpoint_id = _find_pre_turn_checkpoint_id(
            checkpoint_tuples,
            turn_id=from_message_turn_id,
        )
        if target_checkpoint_id is None and len(checkpoint_tuples) > 0:
            # Fall back to the oldest checkpoint — better than
            # 400ing when the agent didn't checkpoint pre-turn
            # (e.g. very first turn of the thread).
            target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
                "checkpoint_id"
            ]
                break

    # Look through checkpoints to find the right one
    # We want to find the checkpoint just before the last HumanMessage.
    # We enter this branch when:
    # * the client did NOT pin ``from_message_id`` (legacy reload/edit), OR
    # * the client pinned ``from_message_id`` but the row is a
    #   legacy pre-migration-136 row with no ``turn_id`` (we
    #   downgraded to the same heuristic as a regular reload).
    # We DO skip it when a real turn_id pinned ``target_checkpoint_id``
    # — that's the C1 happy path and the heuristic below would just
    # re-derive a worse target.
    if request.from_message_id is None or legacy_from_message:
        for i, cp_tuple in enumerate(checkpoint_tuples):
            # Access the checkpoint's channel_values which contains "messages"
            checkpoint_data = cp_tuple.checkpoint
            channel_values = checkpoint_data.get("channel_values", {})
            state_messages = channel_values.get("messages", [])

            if state_messages:
                last_msg = state_messages[-1]
                # Find a checkpoint where the last message is NOT a HumanMessage
                # This means we're at a state before the user's last message
                if not isinstance(last_msg, HumanMessage):
                    # If no new user_query provided (reload), extract from a later checkpoint
                    if user_query_to_use is None and i > 0:
                        # Get the user query from a more recent checkpoint
                        for prev_cp_tuple in checkpoint_tuples[:i]:
                            prev_checkpoint_data = prev_cp_tuple.checkpoint
                            prev_channel_values = prev_checkpoint_data.get(
                                "channel_values", {}
                            )
                            prev_messages = prev_channel_values.get("messages", [])
                            for msg in reversed(prev_messages):
                                if isinstance(msg, HumanMessage):
                                    q, imgs = split_langchain_human_content(
                                        msg.content
                                    )
                                    user_query_to_use = q
                                    regenerate_image_urls = imgs
                                    break
                            if user_query_to_use is not None and (
                                str(user_query_to_use).strip()
                                or regenerate_image_urls
                            ):
                                break

                    target_checkpoint_id = cp_tuple.config["configurable"][
                        "checkpoint_id"
                    ]
                    break

    # If we couldn't find a good checkpoint, try alternative approaches
    if target_checkpoint_id is None and checkpoint_tuples:

@@ -1472,18 +1821,51 @@ async def regenerate_response(
            detail="Could not determine user query for regeneration. Please provide a user_query.",
        )

    # Get the last two messages to delete AFTER streaming succeeds
    # This prevents data loss if streaming fails
    last_messages_result = await session.execute(
        select(NewChatMessage)
        .filter(NewChatMessage.thread_id == thread_id)
        .order_by(NewChatMessage.created_at.desc())
        .limit(2)
    )
    # Get the messages to delete AFTER streaming succeeds.
    # This prevents data loss if streaming fails.
    #
    # When ``from_message_id`` is set we slice from that message
    # forward (using ``created_at`` so we also catch any tool/system
    # messages persisted into the same turn). Otherwise we keep the
    # legacy "last 2 messages" rewind.
    if request.from_message_id is not None and from_message_created_at is not None:
        last_messages_result = await session.execute(
            select(NewChatMessage)
            .filter(
                NewChatMessage.thread_id == thread_id,
                NewChatMessage.created_at >= from_message_created_at,
            )
            .order_by(NewChatMessage.created_at.desc())
        )
    else:
        last_messages_result = await session.execute(
            select(NewChatMessage)
            .filter(NewChatMessage.thread_id == thread_id)
            .order_by(NewChatMessage.created_at.desc())
            .limit(2)
        )
    messages_to_delete = list(last_messages_result.scalars().all())

    message_ids_to_delete = [msg.id for msg in messages_to_delete]

    # When revert_actions is requested, collect the set of
    # ``chat_turn_id``s present in the slice we're about to delete.
    # Each one will be reverted (best-effort) BEFORE the regenerate
    # stream begins. Legacy rows have ``turn_id=None`` and silently
    # contribute nothing — we already logged the degradation above.
    revert_turn_ids: list[str] = []
    if (
        request.revert_actions
        and request.from_message_id is not None
        and not legacy_from_message
    ):
        seen_turns: set[str] = set()
        for msg in messages_to_delete:
            tid = msg.turn_id
            if tid and tid not in seen_turns:
                seen_turns.add(tid)
                revert_turn_ids.append(tid)

    # Get search space for LLM config
    search_space_result = await session.execute(
        select(SearchSpace).filter(SearchSpace.id == request.search_space_id)

@@ -1507,6 +1889,24 @@ async def regenerate_response(
    # This prevents data loss if streaming fails (network error, LLM error, etc.)
    async def stream_with_cleanup():
        streaming_completed = False
        # Best-effort revert pass BEFORE the regenerate stream begins.
        # Each turn is reverted independently (per-row SAVEPOINTs
        # inside the route helper) and the per-action results are surfaced
        # on a single ``data-revert-results`` SSE event so the frontend
        # can render any failed rows alongside the new turn. Failures here
        # do NOT abort the regeneration — partial rollback is documented
        # behaviour.
        if revert_turn_ids:
            revert_results = await _revert_turns_for_regenerate(
                thread_id=thread_id,
                chat_turn_ids=revert_turn_ids,
                requester_user_id=str(user.id),
            )
            envelope = {
                "type": "data-revert-results",
                "data": revert_results,
            }
            yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
        try:
            async for chunk in stream_new_chat(
                user_query=str(user_query_to_use),


@@ -51,6 +51,11 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
    author_display_name: str | None = None
    author_avatar_url: str | None = None
    token_usage: TokenUsageSummary | None = None
    # Per-turn correlation id (``f"{chat_id}:{ms}"``) from
    # ``configurable.turn_id`` at streaming time. Nullable because
    # legacy rows predate the column; clients should treat NULL as
    # "edit-from-this-message is unavailable".
    turn_id: str | None = None
    model_config = ConfigDict(from_attributes=True)


@@ -241,6 +246,15 @@ class RegenerateRequest(BaseModel):

    For edit, optional user_images (when not None) replaces image URLs resolved from
    checkpoint/DB so the client can send the full user turn (text and/or images).

    Edit-from-arbitrary-position. When ``from_message_id`` is provided
    the route slices conversation history starting at that message (instead of
    the legacy "last 2 messages" rewind), rewinds the LangGraph checkpoint by
    matching ``configurable.turn_id`` stored on the message (added in migration 136), and
    optionally reverts every reversible action emitted in turns at or after
    ``from_message_id``. The revert step is best-effort and runs BEFORE the
    regenerate stream — partial failures are surfaced via SSE
    ``data-revert-results`` and do not abort the regeneration.
    """

    search_space_id: int

@@ -257,6 +271,28 @@ class RegenerateRequest(BaseModel):
        default=None,
        description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB",
    )
    from_message_id: int | None = Field(
        default=None,
        description=(
            "Message id to rewind to. When set, history is sliced "
            "from this message forward and the LangGraph checkpoint is "
            "rewound to the state immediately preceding this turn. Legacy "
            "rows that predate migration 136 have ``turn_id=None`` and "
            "still process — the route logs a warning, skips the "
            "checkpoint rewind, and ignores ``revert_actions`` (no "
            "chat_turn_id available to walk)."
        ),
    )
    revert_actions: bool = Field(
        default=False,
        description=(
            "When true, every reversible action emitted at or "
            "after ``from_message_id`` is reverted before the regenerate "
            "stream begins. Per-action results are surfaced via the "
            "``data-revert-results`` SSE event. Partial failures DO NOT "
            "abort the regeneration."
        ),
    )

    @model_validator(mode="after")
    def _validate_regenerate_user_images(self) -> Self:

@@ -264,6 +300,14 @@ class RegenerateRequest(BaseModel):
            raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed")
        return self

    @model_validator(mode="after")
    def _validate_revert_actions_requires_from_message(self) -> Self:
        if self.revert_actions and self.from_message_id is None:
            raise ValueError(
                "revert_actions requires from_message_id; specify which message to rewind to"
            )
        return self
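An illustrative request body exercising the new fields (ids made up):

    payload = {
        "search_space_id": 7,
        "from_message_id": 1234,  # rewind to this message
        "revert_actions": True,   # undo tool effects from the deleted turns first
    }
    # revert_actions=True without from_message_id fails the validator above.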

# =============================================================================
# Agent Tools Schemas


@@ -584,13 +584,33 @@ class VercelStreamingService:
    # Tool Parts
    # =========================================================================

    def format_tool_input_start(self, tool_call_id: str, tool_name: str) -> str:
    def format_tool_input_start(
        self,
        tool_call_id: str,
        tool_name: str,
        *,
        langchain_tool_call_id: str | None = None,
    ) -> str:
        """
        Format the start of tool input streaming.

        Args:
            tool_call_id: The unique tool call identifier
            tool_name: The name of the tool being called
            tool_call_id: The unique tool call identifier. May be EITHER the
                synthetic ``call_<run_id>`` id derived from LangGraph
                ``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
                OFF, or the unmatched-fallback path under parity_v2) OR
                the authoritative LangChain ``tool_call.id`` (parity_v2
                path: when the provider streams ``tool_call_chunks`` we
                register the ``index`` and reuse the lc-id as the card
                id so live ``tool-input-delta`` events can be routed
                without a downstream join). Either way, the same id is
                preserved across ``tool-input-start`` / ``-delta`` /
                ``-available`` / ``tool-output-available`` for one call.
            tool_name: The name of the tool being called.
            langchain_tool_call_id: Optional authoritative LangChain
                ``tool_call.id``. When set, surfaces as
                ``langchainToolCallId`` so the frontend can join this card
                to the action-log row written by ``ActionLogMiddleware``.

        Returns:
            str: SSE formatted tool input start part

@@ -598,13 +618,14 @@ class VercelStreamingService:
        Example output:
            data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"}
        """
        return self._format_sse(
            {
                "type": "tool-input-start",
                "toolCallId": tool_call_id,
                "toolName": tool_name,
            }
        )
        payload: dict[str, Any] = {
            "type": "tool-input-start",
            "toolCallId": tool_call_id,
            "toolName": tool_name,
        }
        if langchain_tool_call_id:
            payload["langchainToolCallId"] = langchain_tool_call_id
        return self._format_sse(payload)

    def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:
        """

@@ -629,7 +650,12 @@ class VercelStreamingService:
        )

    def format_tool_input_available(
        self, tool_call_id: str, tool_name: str, input_data: dict[str, Any]
        self,
        tool_call_id: str,
        tool_name: str,
        input_data: dict[str, Any],
        *,
        langchain_tool_call_id: str | None = None,
    ) -> str:
        """
        Format the completion of tool input.

@@ -638,6 +664,8 @@ class VercelStreamingService:
            tool_call_id: The tool call identifier
            tool_name: The name of the tool
            input_data: The complete tool input parameters
            langchain_tool_call_id: Optional authoritative LangChain
                ``tool_call.id`` (see ``format_tool_input_start``).

        Returns:
            str: SSE formatted tool input available part

@@ -645,22 +673,34 @@ class VercelStreamingService:
        Example output:
            data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}}
        """
        return self._format_sse(
            {
                "type": "tool-input-available",
                "toolCallId": tool_call_id,
                "toolName": tool_name,
                "input": input_data,
            }
        )
        payload: dict[str, Any] = {
            "type": "tool-input-available",
            "toolCallId": tool_call_id,
            "toolName": tool_name,
            "input": input_data,
        }
        if langchain_tool_call_id:
            payload["langchainToolCallId"] = langchain_tool_call_id
        return self._format_sse(payload)

    def format_tool_output_available(self, tool_call_id: str, output: Any) -> str:
    def format_tool_output_available(
        self,
        tool_call_id: str,
        output: Any,
        *,
        langchain_tool_call_id: str | None = None,
    ) -> str:
        """
        Format tool execution output.

        Args:
            tool_call_id: The tool call identifier
            output: The tool execution result
            langchain_tool_call_id: Optional authoritative LangChain
                ``tool_call.id`` extracted from ``ToolMessage.tool_call_id``.
                When set, the frontend can backfill any card whose
                ``langchainToolCallId`` was not yet known at
                ``tool-input-start`` time.

        Returns:
            str: SSE formatted tool output available part

@@ -668,13 +708,14 @@ class VercelStreamingService:
        Example output:
            data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}}
        """
        return self._format_sse(
            {
                "type": "tool-output-available",
                "toolCallId": tool_call_id,
                "output": output,
            }
        )
        payload: dict[str, Any] = {
            "type": "tool-output-available",
            "toolCallId": tool_call_id,
            "output": output,
        }
        if langchain_tool_call_id:
            payload["langchainToolCallId"] = langchain_tool_call_id
        return self._format_sse(payload)
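    # Illustrative wire frame with the join id present (values made up):
    #   data: {"type":"tool-output-available","toolCallId":"call_abc123",
    #          "output":{"ok":true},"langchainToolCallId":"toolu_01XYZ"}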

    # =========================================================================
    # Step Parts


@@ -8,7 +8,9 @@ Operation outcomes mirror the plan:

* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
  :class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
  written before the original mutation.
  written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh
  row from the snapshot; ``write_file`` create / ``mkdir`` DELETE the row
  that was created; everything else is an in-place restore.
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
  the inverse tool through the agent's normal permission stack (NOT
  bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``.

@@ -18,6 +20,11 @@ Operation outcomes mirror the plan:
A successful revert appends a NEW row to ``agent_action_log`` with
``reverse_of=<original_action_id>`` and the requesting user's
``user_id``, preserving an auditable chain.

Dispatch must be exact-match (``tool_name == name``), NOT prefix matching.
``"rmdir".startswith("rm")`` would otherwise mis-route directory reverts
to the document branch (and ``delete_note`` vs ``delete_folder`` is the
same trap waiting to happen).
"""

from __future__ import annotations

@@ -25,17 +32,31 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Literal
from typing import Any, Literal

from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.path_resolver import (
    DOCUMENTS_ROOT,
    safe_filename,
    safe_folder_segment,
)
from app.db import (
    AgentActionLog,
    Chunk,
    Document,
    DocumentRevision,
    DocumentType,
    Folder,
    FolderRevision,
    NewChatThread,
)
from app.utils.document_converters import (
    embed_texts,
    generate_content_hash,
    generate_unique_identifier_hash,
)

logger = logging.getLogger(__name__)


@@ -110,14 +131,244 @@ def can_revert(


# ---------------------------------------------------------------------------
# Revert paths
# Helper: reconstruct virtual path from a snapshot
# ---------------------------------------------------------------------------


async def _virtual_path_from_snapshot(
    session: AsyncSession,
    revision: DocumentRevision,
) -> str | None:
    """Reconstruct the virtual_path the document was at before mutation.

    Preference order:
    1. ``metadata_before["virtual_path"]`` — written by every snapshot
       helper since this PR.
    2. Compose ``"<folder_path>/<title_before>"`` from
       ``folder_id_before`` + ``title_before``. Walks the folder chain via
       ``parent_id``.
    """
    metadata = revision.metadata_before or {}
    candidate = metadata.get("virtual_path") if isinstance(metadata, dict) else None
    if isinstance(candidate, str) and candidate.startswith(DOCUMENTS_ROOT):
        return candidate

    title = revision.title_before
    if not isinstance(title, str) or not title:
        return None

    parts: list[str] = []
    cursor: int | None = revision.folder_id_before
    visited: set[int] = set()
    while cursor is not None and cursor not in visited:
        visited.add(cursor)
        folder = await session.get(Folder, cursor)
        if folder is None:
            return None
        parts.append(safe_folder_segment(str(folder.name or "")))
        cursor = folder.parent_id
    parts.reverse()

    base = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT
    filename = safe_filename(title)
    return f"{base}/{filename}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document revision restore (write/edit/move/rm)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _set_field(target: Any, field: str, value: Any) -> None:
|
||||
if value is not None:
|
||||
setattr(target, field, value)
|
||||
|
||||
|
||||
async def _restore_in_place_document(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
revision: DocumentRevision,
|
||||
) -> RevertOutcome:
|
||||
"""Apply an in-place restore to an existing :class:`Document`."""
|
||||
if revision.document_id is None:
|
||||
return RevertOutcome(
|
||||
status="tool_unavailable",
|
||||
message=(
|
||||
"Original document was hard-deleted; in-place restore is not possible."
|
||||
),
|
||||
)
|
||||
doc = await session.get(Document, revision.document_id)
|
||||
if doc is None:
|
||||
return RevertOutcome(
status="tool_unavailable",
message="Original document has been deleted; revert cannot proceed.",
)

_set_field(doc, "content", revision.content_before)
_set_field(doc, "source_markdown", revision.content_before)
_set_field(doc, "title", revision.title_before)
_set_field(doc, "folder_id", revision.folder_id_before)
metadata_before = revision.metadata_before or {}
if isinstance(metadata_before, dict) and metadata_before:
doc.document_metadata = dict(metadata_before)

if isinstance(revision.content_before, str):
doc.content_hash = generate_content_hash(
revision.content_before, doc.search_space_id
)

virtual_path = await _virtual_path_from_snapshot(session, revision)
if virtual_path:
doc.unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTE,
virtual_path,
doc.search_space_id,
)

chunks_before = revision.chunks_before
if isinstance(chunks_before, list):
await session.execute(delete(Chunk).where(Chunk.document_id == doc.id))
chunk_texts = [
str(c.get("content"))
for c in chunks_before
if isinstance(c, dict) and isinstance(c.get("content"), str)
]
if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts)
session.add_all(
[
Chunk(document_id=doc.id, content=text, embedding=embedding)
for text, embedding in zip(
chunk_texts, chunk_embeddings, strict=True
)
]
)
if isinstance(revision.content_before, str):
doc.embedding = embed_texts([revision.content_before])[0]

doc.updated_at = datetime.now(UTC)
return RevertOutcome(status="ok", message="Document restored from snapshot.")


async def _reinsert_document_from_revision(
session: AsyncSession,
*,
revision: DocumentRevision,
) -> RevertOutcome:
"""Re-INSERT a deleted :class:`Document` from a snapshot row (``rm`` revert)."""
if not isinstance(revision.title_before, str) or not revision.title_before:
return RevertOutcome(
status="not_reversible",
message="Snapshot lacks title_before; cannot recreate document.",
)
if not isinstance(revision.content_before, str):
return RevertOutcome(
status="not_reversible",
message="Snapshot lacks content_before; cannot recreate document.",
)

virtual_path = await _virtual_path_from_snapshot(session, revision)
if not virtual_path:
return RevertOutcome(
status="not_reversible",
message=(
"Snapshot is missing both metadata_before['virtual_path'] AND "
"a resolvable (folder_id_before, title_before) pair."
),
)

search_space_id = revision.search_space_id
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTE,
virtual_path,
search_space_id,
)
collision = await session.execute(
select(Document.id).where(
Document.search_space_id == search_space_id,
Document.unique_identifier_hash == unique_identifier_hash,
)
)
if collision.scalar_one_or_none() is not None:
return RevertOutcome(
status="tool_unavailable",
message=(
f"A document already exists at '{virtual_path}'; revert would "
"collide. Move the live doc out of the way first."
),
)

metadata = revision.metadata_before or {}
if not isinstance(metadata, dict):
metadata = {}
metadata = dict(metadata)
metadata["virtual_path"] = virtual_path

content = revision.content_before
new_doc = Document(
title=revision.title_before,
document_type=DocumentType.NOTE,
document_metadata=metadata,
content=content,
content_hash=generate_content_hash(content, search_space_id),
unique_identifier_hash=unique_identifier_hash,
source_markdown=content,
search_space_id=search_space_id,
folder_id=revision.folder_id_before,
updated_at=datetime.now(UTC),
)
session.add(new_doc)
await session.flush()

new_doc.embedding = embed_texts([content])[0]
chunk_texts = []
chunks_before = revision.chunks_before
if isinstance(chunks_before, list):
chunk_texts = [
str(c.get("content"))
for c in chunks_before
if isinstance(c, dict) and isinstance(c.get("content"), str)
]
if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts)
session.add_all(
[
Chunk(document_id=new_doc.id, content=text, embedding=embedding)
for text, embedding in zip(chunk_texts, chunk_embeddings, strict=True)
]
)

# Repoint the snapshot at the recreated row so a follow-up revert of
# the same row works as expected.
revision.document_id = new_doc.id
return RevertOutcome(
status="ok",
message=f"Re-inserted document '{revision.title_before}' from snapshot.",
)


async def _delete_created_document(
session: AsyncSession,
*,
revision: DocumentRevision,
) -> RevertOutcome:
"""Delete the document that ``write_file`` created (``content_before IS NULL``)."""
if revision.document_id is None:
return RevertOutcome(
status="ok",
message="No live row to delete (already removed elsewhere).",
)
await session.execute(delete(Document).where(Document.id == revision.document_id))
return RevertOutcome(
status="ok",
message="Deleted the document that was created by this action.",
)


async def _restore_document_revision(
session: AsyncSession, *, action: AgentActionLog
) -> RevertOutcome:
"""Restore the most recent :class:`DocumentRevision` for ``action``."""
"""Dispatch document-level revert based on ``action.tool_name``."""
stmt = (
select(DocumentRevision)
.where(DocumentRevision.agent_action_id == action.id)

@@ -132,23 +383,111 @@ async def _restore_document_revision(
message="No document_revisions row tied to this action.",
)

from app.db import Document  # late import to avoid cycles at module load
tool_name = (action.tool_name or "").lower()

doc = await session.get(Document, revision.document_id)
if doc is None:
if tool_name == "rm":
return await _reinsert_document_from_revision(session, revision=revision)

if tool_name == "write_file" and revision.content_before is None:
return await _delete_created_document(session, revision=revision)

return await _restore_in_place_document(session, revision=revision)
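
# Illustrative decision table for the dispatch above (a sketch, not from
# the diff itself): tool name / snapshot state determine the revert path.
#
#     rm                                   -> _reinsert_document_from_revision
#     write_file, content_before is None   -> _delete_created_document
#     edit_file / move_file / update_note… -> _restore_in_place_document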

# ---------------------------------------------------------------------------
# Folder revision restore (mkdir/rmdir/rename/move)
# ---------------------------------------------------------------------------


async def _restore_in_place_folder(
session: AsyncSession,
*,
revision: FolderRevision,
) -> RevertOutcome:
if revision.folder_id is None:
return RevertOutcome(
status="tool_unavailable",
message="Original document has been deleted; revert cannot proceed.",
message="Original folder was hard-deleted; in-place restore is impossible.",
)
folder = await session.get(Folder, revision.folder_id)
if folder is None:
return RevertOutcome(
status="tool_unavailable",
message="Original folder has been deleted; revert cannot proceed.",
)
_set_field(folder, "name", revision.name_before)
_set_field(folder, "parent_id", revision.parent_id_before)
_set_field(folder, "position", revision.position_before)
folder.updated_at = datetime.now(UTC)
return RevertOutcome(status="ok", message="Folder restored from snapshot.")


async def _reinsert_folder_from_revision(
session: AsyncSession,
*,
revision: FolderRevision,
) -> RevertOutcome:
if not isinstance(revision.name_before, str) or not revision.name_before:
return RevertOutcome(
status="not_reversible",
message="Snapshot lacks name_before; cannot recreate folder.",
)
new_folder = Folder(
name=revision.name_before,
parent_id=revision.parent_id_before,
position=revision.position_before,
search_space_id=revision.search_space_id,
updated_at=datetime.now(UTC),
)
session.add(new_folder)
await session.flush()
revision.folder_id = new_folder.id
return RevertOutcome(
status="ok",
message=f"Re-inserted folder '{revision.name_before}' from snapshot.",
)


async def _delete_created_folder(
session: AsyncSession,
*,
revision: FolderRevision,
) -> RevertOutcome:
if revision.folder_id is None:
return RevertOutcome(
status="ok",
message="No live folder row to delete (already removed elsewhere).",
)
folder_id = revision.folder_id

has_doc = await session.execute(
select(Document.id).where(Document.folder_id == folder_id).limit(1)
)
if has_doc.scalar_one_or_none() is not None:
return RevertOutcome(
status="tool_unavailable",
message=(
"Folder is no longer empty (documents have been added since "
"mkdir); cannot revert."
),
)
has_child = await session.execute(
select(Folder.id).where(Folder.parent_id == folder_id).limit(1)
)
if has_child.scalar_one_or_none() is not None:
return RevertOutcome(
status="tool_unavailable",
message=(
"Folder is no longer empty (sub-folders have been added "
"since mkdir); cannot revert."
),
)

if revision.content_before is not None:
doc.content = revision.content_before
if revision.title_before is not None:
doc.title = revision.title_before
if revision.folder_id_before is not None:
doc.folder_id = revision.folder_id_before
doc.updated_at = datetime.now(UTC)
return RevertOutcome(status="ok", message="Document restored from snapshot.")
await session.execute(delete(Folder).where(Folder.id == folder_id))
return RevertOutcome(
status="ok",
message="Deleted the folder that was created by this action.",
)


async def _restore_folder_revision(

@@ -168,41 +507,44 @@ async def _restore_folder_revision(
message="No folder_revisions row tied to this action.",
)

from app.db import Folder
tool_name = (action.tool_name or "").lower()

folder = await session.get(Folder, revision.folder_id)
if folder is None:
return RevertOutcome(
status="tool_unavailable",
message="Original folder has been deleted; revert cannot proceed.",
)
if tool_name == "rmdir":
return await _reinsert_folder_from_revision(session, revision=revision)

if revision.name_before is not None:
folder.name = revision.name_before
if revision.parent_id_before is not None:
folder.parent_id = revision.parent_id_before
if revision.position_before is not None:
folder.position = revision.position_before
folder.updated_at = datetime.now(UTC)
return RevertOutcome(status="ok", message="Folder restored from snapshot.")
if tool_name == "mkdir":
return await _delete_created_folder(session, revision=revision)

return await _restore_in_place_folder(session, revision=revision)


# Tool-name prefixes that route to KB document / folder revert paths. Kept
# as data so a future PR adding new KB-owned tools doesn't have to touch
# this module's control flow.
_DOC_TOOL_PREFIXES: tuple[str, ...] = (
"edit_file",
"write_file",
"update_memory",
"create_note",
"update_note",
"delete_note",
# ---------------------------------------------------------------------------
# Dispatch
# ---------------------------------------------------------------------------
#
# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``.
# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and
# ``delete_note``/``delete_folder``.

_DOC_TOOLS: frozenset[str] = frozenset(
{
"edit_file",
"write_file",
"move_file",
"rm",
"update_memory",
"create_note",
"update_note",
"delete_note",
}
)
_FOLDER_TOOL_PREFIXES: tuple[str, ...] = (
"mkdir",
"move_file",
"rename_folder",
"delete_folder",
_FOLDER_TOOLS: frozenset[str] = frozenset(
{
"mkdir",
"rmdir",
"rename_folder",
"delete_folder",
}
)
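
# Why exact membership matters here, an illustrative sketch (values are
# examples, not from the diff):
#
#     "rmdir".startswith(("rm",))  # True: a prefix match would route the
#                                  # folder op ``rmdir`` down the DOCUMENT path
#     "rmdir" in _DOC_TOOLS        # False: exact matching keeps it out
#     "rmdir" in _FOLDER_TOOLS     # True: it lands on the FOLDER path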

@@ -220,9 +562,9 @@ async def revert_action(
"""
tool_name = (action.tool_name or "").lower()

if tool_name.startswith(_DOC_TOOL_PREFIXES):
if tool_name in _DOC_TOOLS:
outcome = await _restore_document_revision(session, action=action)
elif tool_name.startswith(_FOLDER_TOOL_PREFIXES):
elif tool_name in _FOLDER_TOOLS:
outcome = await _restore_folder_revision(session, action=action)
elif action.reverse_descriptor:
# Connector-owned reversibles run through the normal permission

@@ -31,6 +31,7 @@ from sqlalchemy.orm import selectinload
from app.agents.multi_agent_chat.integration import create_multi_agent_chat
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.feature_flags import get_flags
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.llm_config import (
AgentConfig,

@@ -72,6 +73,91 @@ _perf_log = get_perf_logger()
logger = logging.getLogger(__name__)


def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.

Returns a dict with three keys:

* ``text`` — concatenated string content (empty string if the chunk
contributes none).
* ``reasoning`` — concatenated reasoning content (empty string if the
chunk contributes none).
* ``tool_call_chunks`` — flat list of LangChain ``tool_call_chunk``
dicts surfaced from either the typed-block list or the
``tool_call_chunks`` attribute.

Background
----------
``AIMessageChunk.content`` can be:

* a ``str`` (most providers), or
* a ``list`` of typed blocks ``{type: 'text' | 'reasoning' |
'tool_call_chunk' | 'tool_use' | ..., text/content/...}`` for
Anthropic, Bedrock, and several reasoning configurations.

Reasoning may also live under
``chunk.additional_kwargs['reasoning_content']`` (some providers
surface it that way instead of as a typed block). Tool-call chunks
may live under ``chunk.tool_call_chunks`` even when ``content`` is a
plain string.

Earlier versions only handled the ``isinstance(content, str)`` branch
and silently dropped reasoning blocks + tool-call chunks emitted by
LangChain ``AIMessageChunk``s.
"""
out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []}
if chunk is None:
return out

content = getattr(chunk, "content", None)
if isinstance(content, str):
if content:
out["text"] = content
elif isinstance(content, list):
text_parts: list[str] = []
reasoning_parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
block_type = block.get("type")
if block_type == "text":
value = block.get("text") or block.get("content") or ""
if isinstance(value, str) and value:
text_parts.append(value)
elif block_type == "reasoning":
value = (
block.get("reasoning")
or block.get("text")
or block.get("content")
or ""
)
if isinstance(value, str) and value:
reasoning_parts.append(value)
elif block_type in ("tool_call_chunk", "tool_use"):
out["tool_call_chunks"].append(block)
if text_parts:
out["text"] = "".join(text_parts)
if reasoning_parts:
out["reasoning"] = "".join(reasoning_parts)

additional = getattr(chunk, "additional_kwargs", None) or {}
if isinstance(additional, dict):
extra_reasoning = additional.get("reasoning_content")
if isinstance(extra_reasoning, str) and extra_reasoning:
existing = out["reasoning"]
out["reasoning"] = (
(existing + extra_reasoning) if existing else extra_reasoning
)

extra_tool_chunks = getattr(chunk, "tool_call_chunks", None)
if isinstance(extra_tool_chunks, list):
for tcc in extra_tool_chunks:
if isinstance(tcc, dict):
out["tool_call_chunks"].append(tcc)

return out
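
# Minimal usage sketch for ``_extract_chunk_parts`` (block shapes are
# illustrative; real providers emit richer blocks):
#
#     from langchain_core.messages import AIMessageChunk
#
#     chunk = AIMessageChunk(
#         content=[
#             {"type": "reasoning", "reasoning": "Comparing sources…"},
#             {"type": "text", "text": "Answer: "},
#         ]
#     )
#     parts = _extract_chunk_parts(chunk)
#     # parts["reasoning"] == "Comparing sources…"
#     # parts["text"] == "Answer: "
#     # parts["tool_call_chunks"] == []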


def format_mentioned_surfsense_docs_as_context(
documents: list[SurfsenseDocsDocument],
) -> str:

@@ -254,6 +340,42 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
)


def _legacy_match_lc_id(
pending_tool_call_chunks: list[dict[str, Any]],
tool_name: str,
run_id: str,
lc_tool_call_id_by_run: dict[str, str],
) -> str | None:
"""Best-effort match a buffered ``tool_call_chunk`` to a tool name.

Pure extract of the legacy in-line match used at ``on_tool_start`` for
parity_v2-OFF and unmatched (chunk path didn't register an index for
this call) tools. Pops the next id-bearing chunk whose ``name``
matches ``tool_name`` (or any id-bearing chunk as a fallback) and
returns its id. Mutates ``pending_tool_call_chunks`` and
``lc_tool_call_id_by_run`` in place.
"""
matched_idx: int | None = None
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("name") == tool_name and tcc.get("id"):
matched_idx = idx
break
if matched_idx is None:
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("id"):
matched_idx = idx
break
if matched_idx is None:
return None
matched = pending_tool_call_chunks.pop(matched_idx)
candidate = matched.get("id")
if isinstance(candidate, str) and candidate:
if run_id:
lc_tool_call_id_by_run[run_id] = candidate
return candidate
return None
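
# Behaviour sketch for ``_legacy_match_lc_id`` (hypothetical values):
#
#     pending = [{"name": "web_search", "id": "call_abc", "args": "{}"}]
#     by_run: dict[str, str] = {}
#     _legacy_match_lc_id(pending, "web_search", "run-1", by_run)
#     # -> "call_abc"; the chunk is consumed (pending == [])
#     # and by_run == {"run-1": "call_abc"}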


async def _stream_agent_events(
agent: Any,
config: dict[str, Any],

@@ -268,6 +390,7 @@ async def _stream_agent_events(
fallback_commit_search_space_id: int | None = None,
fallback_commit_created_by_id: str | None = None,
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
fallback_commit_thread_id: int | None = None,
) -> AsyncGenerator[str, None]:
"""Shared async generator that streams and formats astream_events from the agent.

@@ -300,6 +423,59 @@ async def _stream_agent_events(
active_tool_depth: int = 0  # Track nesting: >0 means we're inside a tool
called_update_memory: bool = False

# Reasoning-block streaming. We open a reasoning block on the
# first reasoning delta of a step, append deltas as they arrive, and
# close it when text starts (the model has switched to writing its
# answer) or ``on_chat_model_end`` fires for the model node. Reuses
# the same Vercel format-helpers as text-start/delta/end.
current_reasoning_id: str | None = None

# Streaming-parity v2 feature flag. When OFF we keep the legacy
# shape: str-only content, no reasoning blocks, no
# ``langchainToolCallId`` propagation. The schema migrations
# (135 / 136) ship unconditionally because they're forward-compatible.
parity_v2 = bool(get_flags().enable_stream_parity_v2)

# Best-effort attach of LangChain ``tool_call_id`` to the synthetic
# ``call_<run_id>`` card id we already emit. We accumulate
# ``tool_call_chunks`` from ``on_chat_model_stream``, key them by
# name, and pop the next unconsumed entry at ``on_tool_start``. The
# authoritative id is later filled in at ``on_tool_end`` from
# ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit
# this list for chunks that already registered into ``index_to_meta``
# below — so this list is reserved for the parity_v2-OFF / unmatched
# fallback path only and never re-pops a chunk we already streamed.
pending_tool_call_chunks: list[dict[str, Any]] = []
lc_tool_call_id_by_run: dict[str, str] = {}

# parity_v2 only: live tool-call argument streaming. ``index_to_meta``
# is keyed by the chunk's ``index`` field — LangChain
# ``ToolCallChunk``s for the same call share an index but only the
# first chunk carries id+name (subsequent ones are id=None,
# name=None, args="<delta>"). We register an index when both id and
# name are observed on a chunk (per ToolCallChunk semantics they
# arrive together on the first chunk), then route every later chunk
# at that index to the same ``ui_id`` as a ``tool-input-delta``.
# ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the
# ``ui_id`` used for that call's ``tool-input-start`` so the matching
# ``tool-output-available`` (emitted from ``on_tool_end``) lands on
# the same card.
index_to_meta: dict[int, dict[str, str]] = {}
ui_tool_call_id_by_run: dict[str, str] = {}
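
# Shape sketch of one streamed tool call across ToolCallChunks
# (hypothetical values): only the first chunk carries id+name; the
# shared ``index`` is what ties the later arg deltas back to that call,
# which is exactly the property ``index_to_meta`` exploits.
#
#     {"index": 0, "id": "call_abc", "name": "write_file", "args": ""}
#     {"index": 0, "id": None, "name": None, "args": '{"path": "note'}
#     {"index": 0, "id": None, "name": None, "args": '.md"}'}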

# Per-tool-end mutable cache for the LangChain tool_call_id resolved
# at ``on_tool_end``. ``_emit_tool_output`` reads this so every
# ``format_tool_output_available`` call automatically carries the
# authoritative id without duplicating the kwarg at every call site.
current_lc_tool_call_id: dict[str, str | None] = {"value": None}

def _emit_tool_output(call_id: str, output: Any) -> str:
return streaming_service.format_tool_output_available(
call_id,
output,
langchain_tool_call_id=current_lc_tool_call_id["value"],
)

def next_thinking_step_id() -> str:
nonlocal thinking_step_counter
thinking_step_counter += 1

@@ -328,22 +504,119 @@ async def _stream_agent_events(
if "surfsense:internal" in event.get("tags", []):
continue  # Suppress middleware-internal LLM tokens (e.g. KB search classification)
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content"):
content = chunk.content
if content and isinstance(content, str):
if current_text_id is None:
completion_event = complete_current_step()
if completion_event:
yield completion_event
if just_finished_tool:
last_active_step_id = None
last_active_step_title = ""
last_active_step_items = []
just_finished_tool = False
current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(current_text_id)
yield streaming_service.format_text_delta(current_text_id, content)
accumulated_text += content
if not chunk:
continue
parts = _extract_chunk_parts(chunk)

reasoning_delta = parts["reasoning"]
text_delta = parts["text"]

# Reasoning streaming. Open a reasoning block on first
# delta; append every subsequent delta until text begins.
# When text starts we close the reasoning block first so the
# frontend sees the natural hand-off. Gated behind the
# parity-v2 flag so legacy deployments keep today's shape.
if parity_v2 and reasoning_delta:
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is None:
completion_event = complete_current_step()
if completion_event:
yield completion_event
if just_finished_tool:
last_active_step_id = None
last_active_step_title = ""
last_active_step_items = []
just_finished_tool = False
current_reasoning_id = streaming_service.generate_reasoning_id()
yield streaming_service.format_reasoning_start(current_reasoning_id)
yield streaming_service.format_reasoning_delta(
current_reasoning_id, reasoning_delta
)

if text_delta:
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(current_reasoning_id)
current_reasoning_id = None
if current_text_id is None:
completion_event = complete_current_step()
if completion_event:
yield completion_event
if just_finished_tool:
last_active_step_id = None
last_active_step_title = ""
last_active_step_items = []
just_finished_tool = False
current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(current_text_id)
yield streaming_service.format_text_delta(current_text_id, text_delta)
accumulated_text += text_delta

# Live tool-call argument streaming. Runs AFTER text/reasoning
# processing so chunks containing both stay in their natural
# wire order (text → text-end → tool-input-start). Active
# text/reasoning are closed inside the registration branch
# before ``tool-input-start`` so the frontend sees a clean
# part boundary even when providers interleave.
if parity_v2 and parts["tool_call_chunks"]:
for tcc in parts["tool_call_chunks"]:
idx = tcc.get("index")

# Register this index when we first see id+name
# TOGETHER. Per LangChain ToolCallChunk semantics the
# first chunk for a tool call carries both fields
# together; later chunks have id=None, name=None and
# only ``args``. Requiring BOTH keeps wire
# ``tool-input-start`` always carrying a real
# toolName (assistant-ui's typed tool-part dispatch
# keys off it).
if idx is not None and idx not in index_to_meta:
lc_id = tcc.get("id")
name = tcc.get("name")
if lc_id and name:
ui_id = lc_id

# Close active text/reasoning so wire
# ordering stays clean even on providers
# that interleave text and tool-call chunks
# within the same stream window.
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(
current_reasoning_id
)
current_reasoning_id = None

index_to_meta[idx] = {
"ui_id": ui_id,
"lc_id": lc_id,
"name": name,
}
yield streaming_service.format_tool_input_start(
ui_id,
name,
langchain_tool_call_id=lc_id,
)

# Emit args delta for any chunk at a registered
# index (including idless continuations). Once an
# index is owned by ``index_to_meta`` we DO NOT
# append to ``pending_tool_call_chunks`` — that list
# is reserved for the parity_v2-OFF / unmatched
# fallback path so it never re-pops chunks already
# consumed here (skip-append).
meta = index_to_meta.get(idx) if idx is not None else None
if meta:
args_chunk = tcc.get("args") or ""
if args_chunk:
yield streaming_service.format_tool_input_delta(
meta["ui_id"], args_chunk
)
else:
pending_tool_call_chunks.append(tcc)

elif event_type == "on_tool_start":
active_tool_depth += 1

@@ -463,6 +736,95 @@ async def _stream_agent_events(
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "rm":
rm_path = (
tool_input.get("path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:]
last_active_step_title = "Deleting file"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Deleting file",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "rmdir":
rmdir_path = (
tool_input.get("path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_path = (
rmdir_path if len(rmdir_path) <= 80 else "…" + rmdir_path[-77:]
)
last_active_step_title = "Deleting folder"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Deleting folder",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "mkdir":
mkdir_path = (
tool_input.get("path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_path = (
mkdir_path if len(mkdir_path) <= 80 else "…" + mkdir_path[-77:]
)
last_active_step_title = "Creating folder"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Creating folder",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "move_file":
src = (
tool_input.get("source_path", "")
if isinstance(tool_input, dict)
else ""
)
dst = (
tool_input.get("destination_path", "")
if isinstance(tool_input, dict)
else ""
)
display_src = src if len(src) <= 60 else "…" + src[-57:]
display_dst = dst if len(dst) <= 60 else "…" + dst[-57:]
last_active_step_title = "Moving file"
last_active_step_items = (
[f"{display_src} → {display_dst}"] if src or dst else []
)
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Moving file",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "write_todos":
todos = (
tool_input.get("todos", []) if isinstance(tool_input, dict) else []
)
todo_count = len(todos) if isinstance(todos, list) else 0
last_active_step_title = "Planning tasks"
last_active_step_items = (
[f"{todo_count} task{'s' if todo_count != 1 else ''}"]
if todo_count
else []
)
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Planning tasks",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "save_document":
doc_title = (
tool_input.get("title", "")

@@ -570,7 +932,15 @@ async def _stream_agent_events(
items=last_active_step_items,
)
else:
last_active_step_title = f"Using {tool_name.replace('_', ' ')}"
# Fallback for tools without a curated thinking-step title
# (typically connector tools, MCP-registered tools, or
# newly added tools that haven't been wired up here yet).
# Render the snake_cased name as a sentence-cased phrase
# so non-technical users see e.g. "Send gmail email"
# rather than the raw identifier "send_gmail_email".
last_active_step_title = (
tool_name.replace("_", " ").strip().capitalize() or tool_name
)
last_active_step_items = []
yield streaming_service.format_thinking_step(
step_id=tool_step_id,

@@ -578,12 +948,65 @@ async def _stream_agent_events(
status="in_progress",
)

tool_call_id = (
f"call_{run_id[:32]}"
if run_id
else streaming_service.generate_tool_call_id()
)
yield streaming_service.format_tool_input_start(tool_call_id, tool_name)
# Resolve the card identity. If the chunk-emission loop
# already registered an ``index`` for this tool call (parity_v2
# path), reuse the same ui_id so the card sees:
# tool-input-start → deltas… → tool-input-available →
# tool-output-available all keyed by lc_id. Otherwise fall
# back to the synthetic ``call_<run_id>`` id and the legacy
# best-effort match against ``pending_tool_call_chunks``.
matched_meta: dict[str, str] | None = None
if parity_v2:
# FIFO over indices 0,1,2…; first unassigned same-name
# match wins. Handles parallel same-name calls (e.g. two
# write_file calls) deterministically as long as the
# model interleaves on_tool_start in the same order it
# streamed the args.
taken_ui_ids = set(ui_tool_call_id_by_run.values())
for meta in index_to_meta.values():
if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
matched_meta = meta
break

tool_call_id: str
langchain_tool_call_id: str | None = None
if matched_meta is not None:
tool_call_id = matched_meta["ui_id"]
langchain_tool_call_id = matched_meta["lc_id"]
# ``tool-input-start`` already fired during chunk
# emission — skip the duplicate. No pruning is needed
# because the chunk-emission loop intentionally never
# appends registered-index chunks to
# ``pending_tool_call_chunks`` (skip-append).
if run_id:
lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"]
else:
tool_call_id = (
f"call_{run_id[:32]}"
if run_id
else streaming_service.generate_tool_call_id()
)
# Legacy fallback: parity_v2 OFF, or parity_v2 ON but the
# provider didn't stream tool_call_chunks for this call
# (no index registered). Run the existing best-effort
# match BEFORE emitting start so we still attach an
# authoritative ``langchainToolCallId`` when possible.
if parity_v2:
langchain_tool_call_id = _legacy_match_lc_id(
pending_tool_call_chunks,
tool_name,
run_id,
lc_tool_call_id_by_run,
)
yield streaming_service.format_tool_input_start(
tool_call_id,
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
)

if run_id:
ui_tool_call_id_by_run[run_id] = tool_call_id

# Sanitize tool_input: strip runtime-injected non-serializable
# values (e.g. LangChain ToolRuntime) before sending over SSE.
if isinstance(tool_input, dict):

@@ -600,6 +1023,7 @@ async def _stream_agent_events(
tool_call_id,
tool_name,
_safe_input,
langchain_tool_call_id=langchain_tool_call_id,
)

elif event_type == "on_tool_end":

@@ -635,12 +1059,42 @@ async def _stream_agent_events(
result.write_succeeded = True
result.verification_succeeded = True

tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
# Look up the SAME card id used at on_tool_start (either the
# parity_v2 lc-id-derived ui_id or the legacy synthetic
# ``call_<run_id>``) so the output event always lands on the
# same card as start/delta/available. Fallback preserves the
# legacy synthetic shape for parity_v2-OFF / unknown-run paths.
tool_call_id = ui_tool_call_id_by_run.get(
run_id,
f"call_{run_id[:32]}" if run_id else "call_unknown",
)
original_step_id = tool_step_ids.get(
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
)
completed_step_ids.add(original_step_id)

# Authoritative LangChain tool_call_id from the returned
# ``ToolMessage``. Falls back to whatever we matched
# at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``)
# if the output isn't a ToolMessage. The value is stored in
# ``current_lc_tool_call_id`` so ``_emit_tool_output``
# picks it up for every output emit below.
#
# Emitted in BOTH parity_v2 and legacy modes: the chat tool
# card needs the LangChain id to match against the
# ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``)
# so the inline Revert button can light up. Reading
# ``raw_output.tool_call_id`` is a cheap, non-mutating attribute
# access that is safe regardless of feature-flag state.
current_lc_tool_call_id["value"] = None
authoritative = getattr(raw_output, "tool_call_id", None)
if isinstance(authoritative, str) and authoritative:
current_lc_tool_call_id["value"] = authoritative
if run_id:
lc_tool_call_id_by_run[run_id] = authoritative
elif run_id and run_id in lc_tool_call_id_by_run:
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]

if tool_name == "read_file":
yield streaming_service.format_thinking_step(
step_id=original_step_id,

@@ -676,6 +1130,41 @@ async def _stream_agent_events(
status="completed",
items=last_active_step_items,
)
elif tool_name == "rm":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Deleting file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "rmdir":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Deleting folder",
status="completed",
items=last_active_step_items,
)
elif tool_name == "mkdir":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Creating folder",
status="completed",
items=last_active_step_items,
)
elif tool_name == "move_file":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Moving file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "write_todos":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Planning tasks",
status="completed",
items=last_active_step_items,
)
elif tool_name == "save_document":
result_str = (
tool_output.get("result", "")

@@ -927,9 +1416,14 @@ async def _stream_agent_events(
items=completed_items,
)
else:
# Fallback completion title — see the matching in-progress
# branch above for the wording rationale.
fallback_title = (
tool_name.replace("_", " ").strip().capitalize() or tool_name
)
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title=f"Using {tool_name.replace('_', ' ')}",
title=fallback_title,
status="completed",
items=last_active_step_items,
)

@@ -940,7 +1434,7 @@ async def _stream_agent_events(
last_active_step_items = []

if tool_name == "generate_podcast":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)

@@ -965,7 +1459,7 @@ async def _stream_agent_events(
"error",
)
elif tool_name == "generate_video_presentation":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)

@@ -993,7 +1487,7 @@ async def _stream_agent_events(
"error",
)
elif tool_name == "generate_image":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)

@@ -1020,12 +1514,12 @@ async def _stream_agent_events(
display_output["content_preview"] = (
content[:500] + "..." if len(content) > 500 else content
)
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
display_output,
)
else:
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{"result": tool_output},
)

@@ -1053,7 +1547,7 @@ async def _stream_agent_events(
)
result_text = _tool_output_to_text(tool_output)
if _tool_output_has_error(tool_output):
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{
"status": "error",

@@ -1062,7 +1556,7 @@ async def _stream_agent_events(
},
)
else:
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{
"status": "completed",

@@ -1072,7 +1566,7 @@ async def _stream_agent_events(
)
elif tool_name == "generate_report":
# Stream the full report result so frontend can render the ReportCard
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)

@@ -1099,7 +1593,7 @@ async def _stream_agent_events(
"error",
)
elif tool_name == "generate_resume":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)

@@ -1150,7 +1644,7 @@ async def _stream_agent_events(
"update_confluence_page",
"delete_confluence_page",
):
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)

@@ -1178,7 +1672,7 @@ async def _stream_agent_events(
if fpath and fpath not in result.sandbox_files:
result.sandbox_files.append(fpath)

yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{
"exit_code": exit_code,

@@ -1213,12 +1707,12 @@ async def _stream_agent_events(
citations[chunk_url]["snippet"] = (
content[:200] + "…" if len(content) > 200 else content
)
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{"status": "completed", "citations": citations},
)
else:
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{"status": "completed", "result_length": len(str(tool_output))},
)

@@ -1276,6 +1770,25 @@ async def _stream_agent_events(
},
)

elif event_type == "on_custom_event" and event.get("name") == "action_log":
# Surface a freshly committed AgentActionLog row so the chat
# tool card can render its Revert button immediately.
data = event.get("data", {})
if data.get("id") is not None:
yield streaming_service.format_data("action-log", data)

elif (
event_type == "on_custom_event"
and event.get("name") == "action_log_updated"
):
# Reversibility flipped in kb_persistence after the SAVEPOINT
# for a destructive op (rm/rmdir/move/edit/write) committed.
# Frontend uses this to flip the card's Revert
# button on without re-fetching the actions list.
data = event.get("data", {})
if data.get("id") is not None:
yield streaming_service.format_data("action-log-updated", data)

elif event_type in ("on_chain_end", "on_agent_end"):
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)

@@ -1293,11 +1806,12 @@ async def _stream_agent_events(

# Safety net: if astream_events was cancelled before
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
# (dirty_paths / staged_dirs / pending_moves) will still be in the
# checkpointed state. Run the SAME shared commit helper here so the
# turn's writes don't get lost on client disconnect, then push the
# delta back into the graph using `as_node=...` so reducers fire as if
# the after_agent hook produced it.
# (dirty_paths / staged_dirs / pending_moves / pending_deletes /
# pending_dir_deletes) will still be in the checkpointed state. Run
# the SAME shared commit helper here so the turn's writes don't get
# lost on client disconnect, then push the delta back into the graph
# using `as_node=...` so reducers fire as if the after_agent hook
# produced it.
if (
fallback_commit_filesystem_mode == FilesystemMode.CLOUD
and fallback_commit_search_space_id is not None

@@ -1305,6 +1819,8 @@ async def _stream_agent_events(
(state_values.get("dirty_paths") or [])
or (state_values.get("staged_dirs") or [])
or (state_values.get("pending_moves") or [])
or (state_values.get("pending_deletes") or [])
or (state_values.get("pending_dir_deletes") or [])
)
):
try:

@@ -1313,6 +1829,7 @@ async def _stream_agent_events(
search_space_id=fallback_commit_search_space_id,
created_by_id=fallback_commit_created_by_id,
filesystem_mode=fallback_commit_filesystem_mode,
thread_id=fallback_commit_thread_id,
dispatch_events=False,
)
if delta:

@@ -1753,13 +2270,33 @@ async def stream_new_chat(

config = {
"configurable": configurable,
"recursion_limit": 80,  # Increase from default 25 to allow more tool iterations
# Effectively uncapped, matching the agent-level
# ``with_config`` default in ``chat_deepagent.create_agent``
# and the unbounded ``while(true)`` loop used by OpenCode's
# ``session/processor.ts``. Real circuit-breakers live in
# middleware: ``DoomLoopMiddleware`` (sliding-window tool
# signature check), plus ``enable_tool_call_limit`` /
# ``enable_model_call_limit`` when those flags are set. The
# original LangGraph default of 25 (and our previous 80
# bump) hit users on legitimate multi-tool plans.
"recursion_limit": 10_000,
}
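
# The real ceiling is middleware, not ``recursion_limit``. A minimal
# sliding-window doom-loop check in the spirit of ``DoomLoopMiddleware``
# (window size, threshold, and names are illustrative, not the actual
# implementation):
#
#     from collections import deque
#
#     def make_doom_loop_check(window: int = 8, max_repeats: int = 4):
#         recent: deque[tuple[str, str]] = deque(maxlen=window)
#
#         def record(tool_name: str, args_json: str) -> bool:
#             # True once one (tool, args) signature dominates the
#             # recent window, i.e. the agent is likely looping.
#             sig = (tool_name, args_json)
#             recent.append(sig)
#             return recent.count(sig) >= max_repeats
#
#         return record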

# Start the message stream
yield streaming_service.format_message_start()
yield streaming_service.format_start_step()

# Surface the per-turn correlation id at the very start of the
# stream so the frontend can stamp it onto the in-flight
# assistant message and replay it via ``appendMessage``
# for durable storage. Tool/action-log events DO carry it later,
# but pure-text turns never produce action-log events; this
# event guarantees the frontend learns the turn id regardless.
yield streaming_service.format_data(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)

# Initial thinking step - analyzing the request
if mentioned_surfsense_docs:
initial_title = "Analyzing referenced content"

@@ -1910,6 +2447,7 @@ async def stream_new_chat(
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(

@@ -2353,11 +2891,22 @@ async def stream_resume_chat(
"request_id": request_id or "unknown",
"turn_id": stream_result.turn_id,
},
"recursion_limit": 80,
# See ``stream_new_chat`` above for rationale: effectively
# uncapped to mirror the agent default and OpenCode's
# session loop. Doom-loop / call-limit middleware enforce
# the real ceiling.
"recursion_limit": 10_000,
}

yield streaming_service.format_message_start()
yield streaming_service.format_start_step()
# Same rationale as ``stream_new_chat``: emit the turn id so
# resumed streams can be persisted with their correlation id
# intact.
yield streaming_service.format_data(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)

_t_stream_start = time.perf_counter()
_first_event_logged = False

@@ -2375,6 +2924,7 @@ async def stream_resume_chat(
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(