mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat: real-graph regressions for unified HITL paths + format pass
This commit is contained in:
parent
adb52fb575
commit
0723702320
34 changed files with 920 additions and 90 deletions
|
|
@ -14,7 +14,7 @@ from langgraph.types import Checkpointer
|
|||
from app.agents.multi_agent_chat.middleware.stack import (
|
||||
build_main_agent_deepagent_middleware,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
||||
from app.agents.multi_agent_chat.subagents.shared.tool_kinds import (
|
||||
ToolsPermissions,
|
||||
)
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from langchain_core.language_models import BaseChatModel
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
|
||||
from app.agents.multi_agent_chat.subagents.shared.tool_kinds import ToolsPermissions
|
||||
from app.agents.new_chat.agent_cache import (
|
||||
flags_signature,
|
||||
get_cache,
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def build_main_agent_system_prompt(
|
|||
custom_system_instructions: str | None = None,
|
||||
use_default_system_instructions: bool = True,
|
||||
citations_enabled: bool = True,
|
||||
model_name: str | None = None, # noqa: ARG001 — kept for caller compatibility
|
||||
model_name: str | None = None,
|
||||
) -> str:
|
||||
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
|
|
@ -62,7 +62,9 @@ def build_main_agent_system_prompt(
|
|||
|
||||
if custom_system_instructions and custom_system_instructions.strip():
|
||||
parts.append(
|
||||
"\n" + custom_system_instructions.format(resolved_today=resolved_today) + "\n"
|
||||
"\n"
|
||||
+ custom_system_instructions.format(resolved_today=resolved_today)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
if use_default_system_instructions:
|
||||
|
|
|
|||
|
|
@ -61,9 +61,7 @@ def slice_decisions_by_tool_call(
|
|||
routed: dict[str, dict[str, Any]] = {}
|
||||
cursor = 0
|
||||
for tool_call_id, action_count in pending_list:
|
||||
routed[tool_call_id] = {
|
||||
"decisions": decisions[cursor : cursor + action_count]
|
||||
}
|
||||
routed[tool_call_id] = {"decisions": decisions[cursor : cursor + action_count]}
|
||||
cursor += action_count
|
||||
return routed
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
def check_cloud_write_namespace(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
path: str,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
) -> str | None:
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
def current_cwd(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
) -> str:
|
||||
cwd = runtime.state.get("cwd") if hasattr(runtime, "state") else None
|
||||
|
|
@ -35,7 +35,7 @@ def current_cwd(
|
|||
|
||||
|
||||
def get_contract_suggested_path(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
) -> str:
|
||||
"""Read the planner's suggested write path; otherwise default to ``notes.md``."""
|
||||
|
|
@ -47,7 +47,7 @@ def get_contract_suggested_path(
|
|||
|
||||
|
||||
def resolve_relative(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
path: str,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
) -> str:
|
||||
|
|
@ -63,7 +63,7 @@ def resolve_relative(
|
|||
|
||||
|
||||
def resolve_write_target_path(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
file_path: str,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
) -> str:
|
||||
|
|
@ -77,7 +77,7 @@ def resolve_write_target_path(
|
|||
|
||||
|
||||
def resolve_move_target_path(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
file_path: str,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
) -> str:
|
||||
|
|
@ -91,7 +91,7 @@ def resolve_move_target_path(
|
|||
|
||||
|
||||
def resolve_list_target_path(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
path: str,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
) -> str:
|
||||
|
|
@ -105,7 +105,7 @@ def resolve_list_target_path(
|
|||
|
||||
|
||||
def normalize_local_mount_path(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
candidate: str,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
) -> str:
|
||||
|
|
|
|||
|
|
@ -9,9 +9,7 @@ from .common import HEADER, SANDBOX_ADDENDUM
|
|||
from .desktop import BODY as DESKTOP_BODY
|
||||
|
||||
|
||||
def build_system_prompt(
|
||||
mode: FilesystemMode, *, sandbox_available: bool
|
||||
) -> str:
|
||||
def build_system_prompt(mode: FilesystemMode, *, sandbox_available: bool) -> str:
|
||||
"""Assemble the FS prompt: common header + mode body + optional sandbox section."""
|
||||
body = CLOUD_BODY if mode == FilesystemMode.CLOUD else DESKTOP_BODY
|
||||
base = HEADER + body
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_cd_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_cd_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
async def async_cd(
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_edit_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_edit_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
async def async_edit_file(
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def wrap_as_python(code: str) -> str:
|
|||
|
||||
|
||||
async def execute_in_sandbox(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
command: str,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
timeout: int | None,
|
||||
|
|
@ -59,14 +59,12 @@ async def execute_in_sandbox(
|
|||
try:
|
||||
return await _try_sandbox_execute(mw, command, runtime, timeout)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Sandbox retry also failed for thread %s", mw._thread_id
|
||||
)
|
||||
logger.exception("Sandbox retry also failed for thread %s", mw._thread_id)
|
||||
return "Error: Code execution is temporarily unavailable. Please try again."
|
||||
|
||||
|
||||
async def _try_sandbox_execute(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
command: str,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
timeout: int | None,
|
||||
|
|
|
|||
|
|
@ -17,13 +17,11 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_execute_code_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_execute_code_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
def sync_execute_code(
|
||||
command: Annotated[
|
||||
str, "Python code to execute. Use print() to see output."
|
||||
],
|
||||
command: Annotated[str, "Python code to execute. Use print() to see output."],
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
timeout: Annotated[
|
||||
int | None,
|
||||
|
|
@ -35,14 +33,10 @@ def create_execute_code_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
|||
return f"Error: timeout must be non-negative, got {timeout}."
|
||||
if timeout > MAX_EXECUTE_TIMEOUT:
|
||||
return f"Error: timeout {timeout}s exceeds maximum ({MAX_EXECUTE_TIMEOUT}s)."
|
||||
return run_async_blocking(
|
||||
execute_in_sandbox(mw, command, runtime, timeout)
|
||||
)
|
||||
return run_async_blocking(execute_in_sandbox(mw, command, runtime, timeout))
|
||||
|
||||
async def async_execute_code(
|
||||
command: Annotated[
|
||||
str, "Python code to execute. Use print() to see output."
|
||||
],
|
||||
command: Annotated[str, "Python code to execute. Use print() to see output."],
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
timeout: Annotated[
|
||||
int | None,
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_list_tree_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_list_tree_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
async def async_list_tree(
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_ls_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_ls_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
async def async_ls(
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_mkdir_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_mkdir_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
async def async_mkdir(
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
async def cloud_move_file(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
source: str,
|
||||
dest: str,
|
||||
|
|
@ -39,8 +39,7 @@ async def cloud_move_file(
|
|||
)
|
||||
if not source.startswith(DOCUMENTS_ROOT + "/"):
|
||||
return (
|
||||
"Error: cloud move_file source must be under /documents/ (got "
|
||||
f"'{source}')."
|
||||
f"Error: cloud move_file source must be under /documents/ (got '{source}')."
|
||||
)
|
||||
if not dest.startswith(DOCUMENTS_ROOT + "/"):
|
||||
return (
|
||||
|
|
@ -89,9 +88,7 @@ async def cloud_move_file(
|
|||
],
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=(
|
||||
f"Moved '{source}' to '{dest}' (will commit at end of turn)."
|
||||
),
|
||||
content=(f"Moved '{source}' to '{dest}' (will commit at end of turn)."),
|
||||
tool_call_id=runtime.tool_call_id,
|
||||
)
|
||||
],
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_move_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_move_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
async def async_move_file(
|
||||
|
|
@ -85,9 +85,7 @@ def create_move_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
|||
] = False,
|
||||
) -> Command | str:
|
||||
return run_async_blocking(
|
||||
async_move_file(
|
||||
source_path, destination_path, runtime, overwrite=overwrite
|
||||
)
|
||||
async_move_file(source_path, destination_path, runtime, overwrite=overwrite)
|
||||
)
|
||||
|
||||
return StructuredTool.from_function(
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_pwd_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_pwd_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
def sync_pwd(
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_read_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_read_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
async def async_read_file(
|
||||
|
|
@ -90,9 +90,7 @@ def create_read_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
|||
"Maximum number of lines to read.",
|
||||
] = 100,
|
||||
) -> Command | str:
|
||||
return run_async_blocking(
|
||||
async_read_file(file_path, runtime, offset, limit)
|
||||
)
|
||||
return run_async_blocking(async_read_file(file_path, runtime, offset, limit))
|
||||
|
||||
return StructuredTool.from_function(
|
||||
name="read_file",
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
async def cloud_rm(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
validated: str,
|
||||
) -> Command | str:
|
||||
|
|
@ -31,8 +31,7 @@ async def cloud_rm(
|
|||
return f"Error: refusing to rm '{validated}'."
|
||||
if not validated.startswith(DOCUMENTS_ROOT + "/"):
|
||||
return (
|
||||
"Error: cloud rm must target a path under /documents/ "
|
||||
f"(got '{validated}')."
|
||||
f"Error: cloud rm must target a path under /documents/ (got '{validated}')."
|
||||
)
|
||||
|
||||
anon = runtime.state.get("kb_anon_doc") or {}
|
||||
|
|
@ -41,14 +40,10 @@ async def cloud_rm(
|
|||
|
||||
staged_dirs = list(runtime.state.get("staged_dirs") or [])
|
||||
if validated in staged_dirs:
|
||||
return (
|
||||
f"Error: '{validated}' is a directory. Use rmdir for "
|
||||
"empty directories."
|
||||
)
|
||||
return f"Error: '{validated}' is a directory. Use rmdir for empty directories."
|
||||
pending_dir_deletes = list(runtime.state.get("pending_dir_deletes") or [])
|
||||
if any(
|
||||
isinstance(d, dict) and d.get("path") == validated
|
||||
for d in pending_dir_deletes
|
||||
isinstance(d, dict) and d.get("path") == validated for d in pending_dir_deletes
|
||||
):
|
||||
return f"Error: '{validated}' is already queued for rmdir."
|
||||
|
||||
|
|
@ -57,14 +52,11 @@ async def cloud_rm(
|
|||
children = await backend.als_info(validated)
|
||||
if children:
|
||||
return (
|
||||
f"Error: '{validated}' is a directory. Use rmdir for "
|
||||
"empty directories."
|
||||
f"Error: '{validated}' is a directory. Use rmdir for empty directories."
|
||||
)
|
||||
|
||||
pending_deletes = list(runtime.state.get("pending_deletes") or [])
|
||||
if any(
|
||||
isinstance(d, dict) and d.get("path") == validated for d in pending_deletes
|
||||
):
|
||||
if any(isinstance(d, dict) and d.get("path") == validated for d in pending_deletes):
|
||||
return f"'{validated}' is already queued for deletion."
|
||||
|
||||
files_state = runtime.state.get("files") or {}
|
||||
|
|
@ -93,8 +85,7 @@ async def cloud_rm(
|
|||
"messages": [
|
||||
ToolMessage(
|
||||
content=(
|
||||
f"Staged delete of '{validated}' (will commit at "
|
||||
"end of turn)."
|
||||
f"Staged delete of '{validated}' (will commit at end of turn)."
|
||||
),
|
||||
tool_call_id=runtime.tool_call_id,
|
||||
)
|
||||
|
|
@ -114,7 +105,7 @@ async def cloud_rm(
|
|||
|
||||
|
||||
async def desktop_rm(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
validated: str,
|
||||
) -> Command | str:
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_rm_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_rm_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
async def async_rm(
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
async def cloud_rmdir(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
validated: str,
|
||||
) -> Command | str:
|
||||
|
|
@ -49,8 +49,7 @@ async def cloud_rmdir(
|
|||
staged_dirs = list(runtime.state.get("staged_dirs") or [])
|
||||
pending_dir_deletes = list(runtime.state.get("pending_dir_deletes") or [])
|
||||
if any(
|
||||
isinstance(d, dict) and d.get("path") == validated
|
||||
for d in pending_dir_deletes
|
||||
isinstance(d, dict) and d.get("path") == validated for d in pending_dir_deletes
|
||||
):
|
||||
return f"'{validated}' is already queued for deletion."
|
||||
|
||||
|
|
@ -61,11 +60,7 @@ async def cloud_rmdir(
|
|||
if isinstance(backend, KBPostgresBackend):
|
||||
children = list(await backend.als_info(validated))
|
||||
|
||||
if (
|
||||
isinstance(backend, KBPostgresBackend)
|
||||
and not children
|
||||
and not exists_in_staged
|
||||
):
|
||||
if isinstance(backend, KBPostgresBackend) and not children and not exists_in_staged:
|
||||
loaded = await backend._load_file_data(validated)
|
||||
if loaded is not None:
|
||||
return f"Error: '{validated}' is a file. Use rm to delete files."
|
||||
|
|
@ -79,9 +74,7 @@ async def cloud_rmdir(
|
|||
return f"Error: directory '{validated}' not found."
|
||||
|
||||
if children:
|
||||
return (
|
||||
f"Error: directory '{validated}' is not empty. Remove contents first."
|
||||
)
|
||||
return f"Error: directory '{validated}' is not empty. Remove contents first."
|
||||
|
||||
if exists_in_staged:
|
||||
rest = [d for d in staged_dirs if d != validated]
|
||||
|
|
@ -109,8 +102,7 @@ async def cloud_rmdir(
|
|||
"messages": [
|
||||
ToolMessage(
|
||||
content=(
|
||||
f"Staged rmdir of '{validated}' (will commit "
|
||||
"at end of turn)."
|
||||
f"Staged rmdir of '{validated}' (will commit at end of turn)."
|
||||
),
|
||||
tool_call_id=runtime.tool_call_id,
|
||||
)
|
||||
|
|
@ -120,7 +112,7 @@ async def cloud_rmdir(
|
|||
|
||||
|
||||
async def desktop_rmdir(
|
||||
mw: "SurfSenseFilesystemMiddleware",
|
||||
mw: SurfSenseFilesystemMiddleware,
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
validated: str,
|
||||
) -> Command | str:
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_rmdir_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_rmdir_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
async def async_rmdir(
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def create_write_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
||||
def create_write_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
|
||||
description = select_description(mw._filesystem_mode)
|
||||
|
||||
async def async_write_file(
|
||||
|
|
@ -73,9 +73,7 @@ def create_write_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
|
|||
content: Annotated[str, "Text content to write to the file."],
|
||||
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||
) -> Command | str:
|
||||
return run_async_blocking(
|
||||
async_write_file(file_path, content, runtime)
|
||||
)
|
||||
return run_async_blocking(async_write_file(file_path, content, runtime))
|
||||
|
||||
return StructuredTool.from_function(
|
||||
name="write_file",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,272 @@
|
|||
"""Real-graph parallel HITL across both approval kinds — the keystone regression.
|
||||
|
||||
Pre-fix bug: the parallel-HITL routing layer (``collect_pending_tool_calls``
|
||||
+ ``slice_decisions_by_tool_call`` + ``build_lg_resume_map``) only
|
||||
recognized middleware-gated approvals (LC HITL shape from
|
||||
``HumanInTheLoopMiddleware``). Self-gated approvals from
|
||||
``request_approval`` and middleware-gated permission asks from
|
||||
``PermissionMiddleware`` both used the SurfSense-specific
|
||||
``{type, action, context}`` shape, so when the orchestrator dispatched
|
||||
two parallel ``task`` calls — one self-gated, one middleware-gated — only
|
||||
one interrupt was visible to the routing layer and resume crashed with
|
||||
``Decision count mismatch``.
|
||||
|
||||
This test fans out two real subagents via ``Send``: one calls
|
||||
``request_approval`` (self-gated), the other calls
|
||||
``request_permission_decision`` (middleware-gated). Both pause; the routing
|
||||
layer must see TWO LC HITL interrupts, slice the decisions by
|
||||
``tool_call_id``, key by ``Interrupt.id``, and resume both branches with
|
||||
their per-slice payload.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command, Send
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
build_lg_resume_map,
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import (
|
||||
request_permission_decision,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
|
||||
request_approval,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule
|
||||
|
||||
|
||||
class _SubState(TypedDict, total=False):
|
||||
messages: list
|
||||
|
||||
|
||||
class _DispatchState(TypedDict, total=False):
|
||||
# ``add_messages`` is mandatory: parallel ``Send`` branches both append
|
||||
# to ``messages`` in the same superstep; without a reducer langgraph
|
||||
# raises ``InvalidUpdateError``.
|
||||
messages: Annotated[list, add_messages]
|
||||
tcid: str
|
||||
desc: str
|
||||
subtype: str
|
||||
|
||||
|
||||
def _build_self_gated_subagent(checkpointer: InMemorySaver):
|
||||
"""Subagent that pauses via ``request_approval`` (self-gated path)."""
|
||||
|
||||
def gate_node(_state):
|
||||
result = request_approval(
|
||||
action_type="gmail_email_send",
|
||||
tool_name="send_gmail_email",
|
||||
params={"to": "alice@example.com"},
|
||||
)
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content=json.dumps(
|
||||
{
|
||||
"kind": "self_gated",
|
||||
"decision_type": result.decision_type,
|
||||
"params": result.params,
|
||||
"rejected": result.rejected,
|
||||
},
|
||||
sort_keys=True,
|
||||
)
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("gate", gate_node)
|
||||
g.add_edge(START, "gate")
|
||||
g.add_edge("gate", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _build_middleware_gated_subagent(checkpointer: InMemorySaver):
|
||||
"""Subagent that pauses via ``request_permission_decision`` (middleware-gated path)."""
|
||||
|
||||
def perm_node(_state):
|
||||
decision = request_permission_decision(
|
||||
tool_name="rm",
|
||||
args={"path": "/tmp/file"},
|
||||
patterns=["rm/*"],
|
||||
rules=[Rule(permission="rm", pattern="*", action="ask")],
|
||||
emit_interrupt=True,
|
||||
)
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content=json.dumps(
|
||||
{"kind": "middleware_gated", "decision": decision},
|
||||
sort_keys=True,
|
||||
)
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("perm", perm_node)
|
||||
g.add_edge(START, "perm")
|
||||
g.add_edge("perm", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _build_mixed_task_tool(checkpointer: InMemorySaver):
|
||||
"""Two subagents, one per approval kind, registered under distinct names."""
|
||||
return build_task_tool_with_parent_config(
|
||||
[
|
||||
{
|
||||
"name": "self-gated-agent",
|
||||
"description": "uses request_approval",
|
||||
"runnable": _build_self_gated_subagent(checkpointer),
|
||||
},
|
||||
{
|
||||
"name": "middleware-gated-agent",
|
||||
"description": "uses request_permission_decision",
|
||||
"runnable": _build_middleware_gated_subagent(checkpointer),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _parent_dispatching_one_of_each(
|
||||
task_tool, *, tcid_self: str, tcid_mw: str, checkpointer
|
||||
):
|
||||
def fanout_edge(_state) -> list[Send]:
|
||||
return [
|
||||
Send(
|
||||
"call_task",
|
||||
{"tcid": tcid_self, "desc": "approve email", "subtype": "self-gated-agent"},
|
||||
),
|
||||
Send(
|
||||
"call_task",
|
||||
{
|
||||
"tcid": tcid_mw,
|
||||
"desc": "approve rm",
|
||||
"subtype": "middleware-gated-agent",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||
rt = ToolRuntime(
|
||||
state=state,
|
||||
config=config,
|
||||
context=None,
|
||||
stream_writer=None,
|
||||
tool_call_id=state["tcid"],
|
||||
store=None,
|
||||
)
|
||||
return await task_tool.coroutine(
|
||||
description=state["desc"], subagent_type=state["subtype"], runtime=rt
|
||||
)
|
||||
|
||||
g = StateGraph(_DispatchState)
|
||||
g.add_node("call_task", call_task)
|
||||
g.add_conditional_edges(START, fanout_edge, ["call_task"])
|
||||
g.add_edge("call_task", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_self_gated_and_middleware_gated_route_and_resume_cleanly():
|
||||
"""Both interrupt kinds must reach the routing layer in LC HITL shape and resume independently."""
|
||||
checkpointer = InMemorySaver()
|
||||
task_tool = _build_mixed_task_tool(checkpointer)
|
||||
tcid_self = "tcid-self-gated"
|
||||
tcid_mw = "tcid-middleware-gated"
|
||||
parent = _parent_dispatching_one_of_each(
|
||||
task_tool,
|
||||
tcid_self=tcid_self,
|
||||
tcid_mw=tcid_mw,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
config: dict = {
|
||||
"configurable": {"thread_id": "mixed-parallel"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
paused = await parent.aget_state(config)
|
||||
assert len(paused.interrupts) == 2, (
|
||||
"fixture broken: expected one paused interrupt per approval kind"
|
||||
)
|
||||
|
||||
# Both interrupts must speak the same wire shape — the whole point of
|
||||
# the unification. If either one regresses to the legacy SurfSense shape
|
||||
# ``collect_pending_tool_calls`` would silently skip it and the count
|
||||
# below would be 1.
|
||||
pending = collect_pending_tool_calls(paused)
|
||||
assert dict(pending) == {tcid_self: 1, tcid_mw: 1}, (
|
||||
f"REGRESSION: not all interrupt kinds reached the routing layer; "
|
||||
f"got {pending!r}"
|
||||
)
|
||||
|
||||
# Verify the actual wire payloads carry the LC HITL standard fields
|
||||
# (extra defensive assertion against partial regressions where one
|
||||
# path stamps tool_call_id but reverts the body shape).
|
||||
interrupt_types = {i.value.get("interrupt_type") for i in paused.interrupts}
|
||||
assert interrupt_types == {"gmail_email_send", "permission_ask"}
|
||||
|
||||
# Resume order: same order the SSE stream would emit (interrupts list).
|
||||
decision_self = {"type": "approve"}
|
||||
decision_mw = {"type": "always"}
|
||||
flat_decisions = [
|
||||
# Match `pending` order.
|
||||
decision_self if pending[0][0] == tcid_self else decision_mw,
|
||||
decision_mw if pending[0][0] == tcid_self else decision_self,
|
||||
]
|
||||
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
||||
lg_resume_map = build_lg_resume_map(paused, by_tool_call_id)
|
||||
assert len(lg_resume_map) == 2
|
||||
|
||||
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
||||
await parent.ainvoke(Command(resume=lg_resume_map), config)
|
||||
|
||||
final = await parent.aget_state(config)
|
||||
assert not final.interrupts, (
|
||||
f"expected both branches resumed, but state still has interrupts: "
|
||||
f"{final.interrupts!r}"
|
||||
)
|
||||
|
||||
# Each subagent must have received its own slice — verify by inspecting
|
||||
# the JSON-serialized result messages.
|
||||
payloads: list[dict] = []
|
||||
for msg in final.values.get("messages", []) or []:
|
||||
content = getattr(msg, "content", None)
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
payloads.append(json.loads(content))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
self_payloads = [p for p in payloads if p.get("kind") == "self_gated"]
|
||||
mw_payloads = [p for p in payloads if p.get("kind") == "middleware_gated"]
|
||||
assert len(self_payloads) == 1, (
|
||||
f"self-gated subagent did not complete; payloads: {payloads!r}"
|
||||
)
|
||||
assert len(mw_payloads) == 1, (
|
||||
f"middleware-gated subagent did not complete; payloads: {payloads!r}"
|
||||
)
|
||||
|
||||
# Self-gated approve → HITLResult(decision_type="approve", rejected=False).
|
||||
assert self_payloads[0]["decision_type"] == "approve"
|
||||
assert self_payloads[0]["rejected"] is False
|
||||
|
||||
# Middleware-gated always → canonical permission shape with always.
|
||||
assert mw_payloads[0]["decision"] == {"decision_type": "always"}
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
"""Regression: ``request_permission_decision`` must emit the unified LC HITL wire shape.
|
||||
|
||||
Same bug class as :mod:`test_lc_hitl_wire` for self-gated approvals: the
|
||||
permission middleware previously fired the SurfSense-specific
|
||||
``{type, action, context}`` shape, which the parallel-HITL routing layer
|
||||
does not recognize. Standardizing on LC HITL keeps every approval kind on
|
||||
one routing path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import (
|
||||
request_permission_decision,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: list
|
||||
final_decision: dict
|
||||
|
||||
|
||||
def _build_graph_calling_request_permission_decision(checkpointer: InMemorySaver):
|
||||
"""Real graph whose only node delegates to the permission ask primitive."""
|
||||
|
||||
def perm_node(_state):
|
||||
decision = request_permission_decision(
|
||||
tool_name="rm",
|
||||
args={"path": "/tmp/file"},
|
||||
patterns=["rm/*"],
|
||||
rules=[Rule(permission="rm", pattern="*", action="ask")],
|
||||
emit_interrupt=True,
|
||||
)
|
||||
return {"final_decision": decision}
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("perm", perm_node)
|
||||
g.add_edge(START, "perm")
|
||||
g.add_edge("perm", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_ask_payload_uses_lc_hitl_shape():
|
||||
"""The permission middleware now speaks the langchain HITL standard shape."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_permission_decision(checkpointer)
|
||||
config = {"configurable": {"thread_id": "perm-wire"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
snap = await graph.aget_state(config)
|
||||
assert len(snap.interrupts) == 1
|
||||
value = snap.interrupts[0].value
|
||||
|
||||
assert value.get("action_requests") == [
|
||||
{"name": "rm", "args": {"path": "/tmp/file"}}
|
||||
], f"REGRESSION: permission ask reverted to legacy shape; got {value!r}"
|
||||
review = value.get("review_configs")
|
||||
assert isinstance(review, list) and len(review) == 1
|
||||
# ``always`` must be in the palette so the FE can render the promote button.
|
||||
assert "always" in review[0]["allowed_decisions"]
|
||||
assert value.get("interrupt_type") == "permission_ask"
|
||||
# SurfSense context rides through verbatim for FE explainability.
|
||||
assert value["context"]["patterns"] == ["rm/*"]
|
||||
assert value["context"]["always"] == ["rm/*"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_with_approve_envelope_returns_once_decision():
|
||||
"""``approve`` from the LC envelope projects to permission-domain ``once``."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_permission_decision(checkpointer)
|
||||
config = {"configurable": {"thread_id": "perm-once"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
await graph.ainvoke(
|
||||
Command(resume={"decisions": [{"type": "approve"}]}), config
|
||||
)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.values.get("final_decision") == {"decision_type": "once"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_with_always_envelope_projects_to_always():
|
||||
"""``always`` reply must project unchanged so the middleware can promote the rule."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_permission_decision(checkpointer)
|
||||
config = {"configurable": {"thread_id": "perm-always"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
await graph.ainvoke(
|
||||
Command(resume={"decisions": [{"type": "always"}]}), config
|
||||
)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.values.get("final_decision") == {"decision_type": "always"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_with_reject_and_feedback_carries_feedback_through():
|
||||
"""Reject feedback must survive normalization for ``CorrectedError`` to fire downstream."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_permission_decision(checkpointer)
|
||||
config = {"configurable": {"thread_id": "perm-reject"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
await graph.ainvoke(
|
||||
Command(
|
||||
resume={
|
||||
"decisions": [{"type": "reject", "feedback": "use the trash bin"}]
|
||||
}
|
||||
),
|
||||
config,
|
||||
)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.values.get("final_decision") == {
|
||||
"decision_type": "reject",
|
||||
"feedback": "use the trash bin",
|
||||
}
|
||||
|
|
@ -0,0 +1,169 @@
|
|||
"""Regression: subagent-owned rulesets layer cleanly into ``PermissionMiddleware``.
|
||||
|
||||
The KB unification swap (legacy ``interrupt_on`` map → KB-owned ``Ruleset``
|
||||
threaded through ``build_permission_mw(extra_rulesets=...)``) must produce
|
||||
*exactly one* interrupt per destructive FS call, in LC HITL shape, even
|
||||
when ``enable_permission`` is False — destructive ops always ask.
|
||||
|
||||
We exercise the production factory and a real ``PermissionMiddleware`` on a
|
||||
real ``StateGraph`` so the test catches regressions in factory gating,
|
||||
ruleset layering, and interrupt emission together.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||
build_permission_mw,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
|
||||
def _kb_style_ruleset() -> Ruleset:
|
||||
"""Mirror :data:`knowledge_base.agent.KB_RULESET` without importing it.
|
||||
|
||||
Importing the agent module pulls in deepagents and prompts; this test
|
||||
is about the factory + middleware contract, not KB wiring.
|
||||
"""
|
||||
return Ruleset(
|
||||
origin="knowledge_base",
|
||||
rules=[
|
||||
Rule(permission="rm", pattern="*", action="ask"),
|
||||
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||
Rule(permission="move_file", pattern="*", action="ask"),
|
||||
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||
Rule(permission="write_file", pattern="*", action="ask"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
def _build_graph_with_permission_middleware(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
extra_rulesets: list[Ruleset] | None,
|
||||
checkpointer: InMemorySaver,
|
||||
):
|
||||
"""Compile a one-node graph that emits a tool call for ``rm`` and
|
||||
routes through the production ``PermissionMiddleware``.
|
||||
|
||||
The node returns an ``AIMessage`` with a tool call. The middleware's
|
||||
``after_model`` hook intercepts and (if a rule says ``ask``) raises
|
||||
a ``GraphInterrupt`` carrying the LC HITL payload.
|
||||
"""
|
||||
pm = build_permission_mw(flags=flags, extra_rulesets=extra_rulesets)
|
||||
|
||||
def node(_state: _State) -> dict[str, Any]:
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "rm",
|
||||
"args": {"path": "/tmp/foo"},
|
||||
"id": "call-rm-1",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
)
|
||||
return {"messages": [msg]}
|
||||
|
||||
def after_node(state: _State) -> dict[str, Any] | None:
|
||||
if pm is None:
|
||||
return None
|
||||
# PermissionMiddleware._process ignores runtime; the test never relies
|
||||
# on the runtime context, so passing None keeps the harness lean.
|
||||
return pm._process(state, None) # type: ignore[arg-type]
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("emit", node)
|
||||
g.add_node("permission", after_node)
|
||||
g.add_edge(START, "emit")
|
||||
g.add_edge("emit", "permission")
|
||||
g.add_edge("permission", END)
|
||||
return g.compile(checkpointer=checkpointer), pm
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kb_ruleset_raises_one_lc_hitl_ask_for_rm_even_when_permission_flag_off():
|
||||
"""KB ruleset: ``rm`` must ask once even with ``enable_permission=False``.
|
||||
|
||||
This is the keystone of the unification: the legacy ``interrupt_on``
|
||||
map fired regardless of ``enable_permission``, so the migrated rules
|
||||
must too. Otherwise users could opt out of "ask before rm".
|
||||
"""
|
||||
flags = AgentFeatureFlags(enable_permission=False)
|
||||
checkpointer = InMemorySaver()
|
||||
graph, pm = _build_graph_with_permission_middleware(
|
||||
flags=flags,
|
||||
extra_rulesets=[_kb_style_ruleset()],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
assert pm is not None, "extras must force the middleware on"
|
||||
|
||||
config = {"configurable": {"thread_id": "kb-cloud-rm"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
snap = await graph.aget_state(config)
|
||||
assert len(snap.interrupts) == 1, (
|
||||
f"REGRESSION: KB ruleset should raise exactly one interrupt; got "
|
||||
f"{[i.value for i in snap.interrupts]!r}"
|
||||
)
|
||||
payload = snap.interrupts[0].value
|
||||
requests = payload.get("action_requests")
|
||||
assert requests == [{"name": "rm", "args": {"path": "/tmp/foo"}}], (
|
||||
f"interrupt must carry the rm call in LC HITL shape; got {payload!r}"
|
||||
)
|
||||
assert payload.get("interrupt_type") == "permission_ask"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kb_ruleset_resume_with_approve_lets_rm_through():
|
||||
"""Resume with ``approve`` → call kept; the model continues normally."""
|
||||
flags = AgentFeatureFlags(enable_permission=False)
|
||||
checkpointer = InMemorySaver()
|
||||
graph, _ = _build_graph_with_permission_middleware(
|
||||
flags=flags,
|
||||
extra_rulesets=[_kb_style_ruleset()],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
config = {"configurable": {"thread_id": "kb-cloud-rm-approve"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
await graph.ainvoke(
|
||||
Command(resume={"decisions": [{"type": "approve"}]}), config
|
||||
)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.next == (), "graph must complete after approve"
|
||||
last_ai = next(
|
||||
(m for m in reversed(final.values["messages"]) if isinstance(m, AIMessage)),
|
||||
None,
|
||||
)
|
||||
assert last_ai is not None
|
||||
assert [tc["name"] for tc in last_ai.tool_calls] == ["rm"], (
|
||||
"approved rm call must remain on the AIMessage so the tool can run"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_extras_with_permission_off_skips_middleware_entirely():
|
||||
"""No extras + permission off → factory returns ``None`` (no engine).
|
||||
|
||||
The legacy gating is preserved when no caller asks for rules: nothing
|
||||
runs, nothing pauses.
|
||||
"""
|
||||
flags = AgentFeatureFlags(enable_permission=False)
|
||||
pm = build_permission_mw(flags=flags, extra_rulesets=None)
|
||||
assert pm is None
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
"""Regression: ``request_approval`` must emit the unified LC HITL wire shape.
|
||||
|
||||
Before this fix, self-gated approvals fired the SurfSense-specific
|
||||
``{type, action, context}`` shape which the parallel-HITL routing layer
|
||||
(``collect_pending_tool_calls``) does not recognize. In a parallel HITL
|
||||
scenario where one subagent used self-gated approvals (e.g. Gmail send)
|
||||
and another used middleware-gated approvals (e.g. Linear via
|
||||
``HumanInTheLoopMiddleware``), the routing layer would silently skip the
|
||||
self-gated interrupt and crash on resume with ``Decision count mismatch``.
|
||||
|
||||
This test pins the wire contract by running ``request_approval`` inside a
|
||||
real ``StateGraph`` and asserting the paused parent observes the LC HITL
|
||||
shape (``action_requests``, ``review_configs``, ``interrupt_type``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
|
||||
request_approval,
|
||||
)
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: list
|
||||
final_decision_type: str
|
||||
final_params: dict
|
||||
|
||||
|
||||
def _build_graph_calling_request_approval(checkpointer: InMemorySaver):
|
||||
"""A real graph whose only node delegates to ``request_approval``."""
|
||||
|
||||
def gate_node(_state):
|
||||
result = request_approval(
|
||||
action_type="gmail_email_send",
|
||||
tool_name="send_gmail_email",
|
||||
params={"to": "alice@example.com", "subject": "hi"},
|
||||
context={"account": "alice@gmail.com"},
|
||||
)
|
||||
return {
|
||||
"final_decision_type": result.decision_type,
|
||||
"final_params": result.params,
|
||||
}
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("gate", gate_node)
|
||||
g.add_edge(START, "gate")
|
||||
g.add_edge("gate", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_paused_interrupt_uses_lc_hitl_action_requests_shape():
|
||||
"""The paused interrupt must speak the langchain HITL standard shape."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_approval(checkpointer)
|
||||
config = {"configurable": {"thread_id": "self-gated-wire"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
snap = await graph.aget_state(config)
|
||||
assert len(snap.interrupts) == 1, (
|
||||
f"expected one paused interrupt, got {len(snap.interrupts)}"
|
||||
)
|
||||
value = snap.interrupts[0].value
|
||||
assert isinstance(value, dict)
|
||||
|
||||
# Standard LC HITL fields the routing layer reads.
|
||||
assert value.get("action_requests") == [
|
||||
{
|
||||
"name": "send_gmail_email",
|
||||
"args": {"to": "alice@example.com", "subject": "hi"},
|
||||
}
|
||||
], (
|
||||
"REGRESSION: self-gated approval reverted to legacy SurfSense shape; "
|
||||
f"got {value!r}"
|
||||
)
|
||||
assert value.get("review_configs") == [
|
||||
{
|
||||
"action_name": "send_gmail_email",
|
||||
"allowed_decisions": ["approve", "reject", "edit"],
|
||||
}
|
||||
]
|
||||
assert value.get("interrupt_type") == "gmail_email_send", (
|
||||
"FE card discriminator must travel as ``interrupt_type``."
|
||||
)
|
||||
assert value.get("context") == {"account": "alice@gmail.com"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_with_lc_envelope_returns_hitl_result_with_edited_args():
|
||||
"""Edit reply via the LC envelope must round-trip into ``HITLResult.params``."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_approval(checkpointer)
|
||||
config = {"configurable": {"thread_id": "self-gated-resume"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
edited = {"to": "alice@example.com", "subject": "EDITED"}
|
||||
await graph.ainvoke(
|
||||
Command(
|
||||
resume={
|
||||
"decisions": [
|
||||
{"type": "edit", "edited_action": {"args": {"subject": "EDITED"}}}
|
||||
]
|
||||
}
|
||||
),
|
||||
config,
|
||||
)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.values.get("final_decision_type") == "edit"
|
||||
assert final.values.get("final_params") == edited
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_envelope_returns_rejected_hitl_result():
|
||||
"""Reject reply must surface as ``HITLResult.rejected=True`` without invoking the tool."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_approval(checkpointer)
|
||||
config = {"configurable": {"thread_id": "self-gated-reject"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
await graph.ainvoke(
|
||||
Command(resume={"decisions": [{"type": "reject", "feedback": "no"}]}),
|
||||
config,
|
||||
)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.values.get("final_decision_type") == "reject"
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
"""Unit contract for the unified LC HITL wire format.
|
||||
|
||||
Both the self-gated approval primitive (``request_approval``) and the
|
||||
middleware-gated permission ask (``PermissionMiddleware``) must serialize
|
||||
to the same wire shape so the parallel-HITL routing layer
|
||||
(``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call`` +
|
||||
``build_lg_resume_map``) sees one format.
|
||||
|
||||
These tests pin the shape:
|
||||
|
||||
- Builder always emits ``action_requests`` (1 entry) + ``review_configs``
|
||||
+ ``interrupt_type``; ``context`` rides through verbatim when present.
|
||||
- Parser tolerates the standard LC envelope, bare scalar strings, and
|
||||
unrecognized shapes (failing closed to ``reject``).
|
||||
- Edited args round-trip through both nested (``edited_action.args``) and
|
||||
flat (``args``) shapes without inventing values for the empty case.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.wire import (
|
||||
LC_DECISION_APPROVE,
|
||||
LC_DECISION_EDIT,
|
||||
LC_DECISION_REJECT,
|
||||
SURFSENSE_DECISION_ALWAYS,
|
||||
build_lc_hitl_payload,
|
||||
parse_lc_envelope,
|
||||
)
|
||||
|
||||
|
||||
class TestBuildLcHitlPayload:
|
||||
def test_minimal_payload_has_one_action_request_and_one_review_config(self):
|
||||
payload = build_lc_hitl_payload(
|
||||
tool_name="send_email",
|
||||
args={"to": "x@y.z"},
|
||||
allowed_decisions=[LC_DECISION_APPROVE, LC_DECISION_REJECT],
|
||||
interrupt_type="gmail_email_send",
|
||||
)
|
||||
assert payload["action_requests"] == [
|
||||
{"name": "send_email", "args": {"to": "x@y.z"}}
|
||||
]
|
||||
assert payload["review_configs"] == [
|
||||
{
|
||||
"action_name": "send_email",
|
||||
"allowed_decisions": [LC_DECISION_APPROVE, LC_DECISION_REJECT],
|
||||
}
|
||||
]
|
||||
assert payload["interrupt_type"] == "gmail_email_send"
|
||||
assert "context" not in payload, "context must be omitted when not provided"
|
||||
|
||||
def test_none_args_normalized_to_empty_dict(self):
|
||||
"""FE expects a stable shape; ``None`` would crash card rendering."""
|
||||
payload = build_lc_hitl_payload(
|
||||
tool_name="ping",
|
||||
args=None, # type: ignore[arg-type]
|
||||
allowed_decisions=[LC_DECISION_APPROVE],
|
||||
interrupt_type="self_gated",
|
||||
)
|
||||
assert payload["action_requests"][0]["args"] == {}
|
||||
|
||||
def test_description_attached_only_when_provided(self):
|
||||
with_desc = build_lc_hitl_payload(
|
||||
tool_name="t",
|
||||
args={},
|
||||
allowed_decisions=[LC_DECISION_APPROVE],
|
||||
interrupt_type="x",
|
||||
description="please review",
|
||||
)
|
||||
without = build_lc_hitl_payload(
|
||||
tool_name="t",
|
||||
args={},
|
||||
allowed_decisions=[LC_DECISION_APPROVE],
|
||||
interrupt_type="x",
|
||||
)
|
||||
assert with_desc["action_requests"][0]["description"] == "please review"
|
||||
assert "description" not in without["action_requests"][0]
|
||||
|
||||
def test_context_passed_through_verbatim(self):
|
||||
ctx = {"patterns": ["rm/*"], "rules": [], "always": ["rm/*"]}
|
||||
payload = build_lc_hitl_payload(
|
||||
tool_name="rm",
|
||||
args={"path": "/tmp"},
|
||||
allowed_decisions=[
|
||||
LC_DECISION_APPROVE,
|
||||
LC_DECISION_REJECT,
|
||||
SURFSENSE_DECISION_ALWAYS,
|
||||
],
|
||||
interrupt_type="permission_ask",
|
||||
context=ctx,
|
||||
)
|
||||
assert payload["context"] == ctx
|
||||
|
||||
def test_allowed_decisions_list_is_copied_not_aliased(self):
|
||||
"""A caller mutating their original list must not corrupt the payload."""
|
||||
decisions = [LC_DECISION_APPROVE]
|
||||
payload = build_lc_hitl_payload(
|
||||
tool_name="t",
|
||||
args={},
|
||||
allowed_decisions=decisions,
|
||||
interrupt_type="x",
|
||||
)
|
||||
decisions.append(LC_DECISION_REJECT)
|
||||
assert payload["review_configs"][0]["allowed_decisions"] == [LC_DECISION_APPROVE]
|
||||
|
||||
|
||||
class TestParseLcEnvelope:
|
||||
def test_standard_lc_envelope_returns_typed_decision(self):
|
||||
parsed = parse_lc_envelope({"decisions": [{"type": "approve"}]})
|
||||
assert parsed.decision_type == "approve"
|
||||
assert parsed.edited_args is None
|
||||
assert parsed.message is None
|
||||
|
||||
def test_bare_scalar_string_passes_through_lowercased(self):
|
||||
assert parse_lc_envelope("ALWAYS").decision_type == "always"
|
||||
assert parse_lc_envelope("once").decision_type == "once"
|
||||
|
||||
def test_non_dict_non_string_collapses_to_reject(self):
|
||||
"""Failing closed: ambiguous input must never proceed."""
|
||||
assert parse_lc_envelope(42).decision_type == "reject"
|
||||
assert parse_lc_envelope(None).decision_type == "reject"
|
||||
assert parse_lc_envelope(["bogus"]).decision_type == "reject"
|
||||
|
||||
def test_missing_decision_type_collapses_to_reject(self):
|
||||
assert parse_lc_envelope({"decisions": [{}]}).decision_type == "reject"
|
||||
assert parse_lc_envelope({"foo": "bar"}).decision_type == "reject"
|
||||
|
||||
def test_edit_extracts_nested_args(self):
|
||||
parsed = parse_lc_envelope(
|
||||
{
|
||||
"decisions": [
|
||||
{
|
||||
"type": LC_DECISION_EDIT,
|
||||
"edited_action": {"args": {"to": "edited@y.z"}},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
assert parsed.decision_type == "edit"
|
||||
assert parsed.edited_args == {"to": "edited@y.z"}
|
||||
|
||||
def test_edit_falls_back_to_flat_args(self):
|
||||
parsed = parse_lc_envelope(
|
||||
{"decisions": [{"type": "edit", "args": {"k": "v"}}]}
|
||||
)
|
||||
assert parsed.edited_args == {"k": "v"}
|
||||
|
||||
def test_edit_with_empty_args_yields_none_edited(self):
|
||||
"""Empty edited_args means "no edits" — caller treats as plain approve."""
|
||||
parsed = parse_lc_envelope(
|
||||
{"decisions": [{"type": "edit", "edited_action": {"args": {}}}]}
|
||||
)
|
||||
assert parsed.edited_args is None
|
||||
|
||||
def test_message_picked_from_either_feedback_or_message_field(self):
|
||||
with_feedback = parse_lc_envelope(
|
||||
{"decisions": [{"type": "reject", "feedback": "no thanks"}]}
|
||||
)
|
||||
with_message = parse_lc_envelope(
|
||||
{"decisions": [{"type": "reject", "message": "no thanks"}]}
|
||||
)
|
||||
assert with_feedback.message == "no thanks"
|
||||
assert with_message.message == "no thanks"
|
||||
|
||||
def test_blank_message_treated_as_absent(self):
|
||||
parsed = parse_lc_envelope(
|
||||
{"decisions": [{"type": "reject", "message": " "}]}
|
||||
)
|
||||
assert parsed.message is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue