multi_agent_chat: real-graph regressions for unified HITL paths + format pass

This commit is contained in:
CREDO23 2026-05-14 17:41:24 +02:00
parent adb52fb575
commit 0723702320
34 changed files with 920 additions and 90 deletions

View file

@ -14,7 +14,7 @@ from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.middleware.stack import (
build_main_agent_deepagent_middleware,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
from app.agents.multi_agent_chat.subagents.shared.tool_kinds import (
ToolsPermissions,
)
from app.agents.new_chat.context import SurfSenseContextSchema

View file

@ -10,7 +10,7 @@ from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
from app.agents.multi_agent_chat.subagents.shared.tool_kinds import ToolsPermissions
from app.agents.new_chat.agent_cache import (
flags_signature,
get_cache,

View file

@ -49,7 +49,7 @@ def build_main_agent_system_prompt(
custom_system_instructions: str | None = None,
use_default_system_instructions: bool = True,
citations_enabled: bool = True,
model_name: str | None = None, # noqa: ARG001 — kept for caller compatibility
model_name: str | None = None,
) -> str:
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
visibility = thread_visibility or ChatVisibility.PRIVATE
@ -62,7 +62,9 @@ def build_main_agent_system_prompt(
if custom_system_instructions and custom_system_instructions.strip():
parts.append(
"\n" + custom_system_instructions.format(resolved_today=resolved_today) + "\n"
"\n"
+ custom_system_instructions.format(resolved_today=resolved_today)
+ "\n"
)
if use_default_system_instructions:

View file

@ -61,9 +61,7 @@ def slice_decisions_by_tool_call(
routed: dict[str, dict[str, Any]] = {}
cursor = 0
for tool_call_id, action_count in pending_list:
routed[tool_call_id] = {
"decisions": decisions[cursor : cursor + action_count]
}
routed[tool_call_id] = {"decisions": decisions[cursor : cursor + action_count]}
cursor += action_count
return routed

View file

@ -22,7 +22,7 @@ if TYPE_CHECKING:
def check_cloud_write_namespace(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
path: str,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> str | None:

View file

@ -25,7 +25,7 @@ if TYPE_CHECKING:
def current_cwd(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> str:
cwd = runtime.state.get("cwd") if hasattr(runtime, "state") else None
@ -35,7 +35,7 @@ def current_cwd(
def get_contract_suggested_path(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> str:
"""Read the planner's suggested write path; otherwise default to ``notes.md``."""
@ -47,7 +47,7 @@ def get_contract_suggested_path(
def resolve_relative(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
path: str,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> str:
@ -63,7 +63,7 @@ def resolve_relative(
def resolve_write_target_path(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
file_path: str,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> str:
@ -77,7 +77,7 @@ def resolve_write_target_path(
def resolve_move_target_path(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
file_path: str,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> str:
@ -91,7 +91,7 @@ def resolve_move_target_path(
def resolve_list_target_path(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
path: str,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> str:
@ -105,7 +105,7 @@ def resolve_list_target_path(
def normalize_local_mount_path(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
candidate: str,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> str:

View file

@ -9,9 +9,7 @@ from .common import HEADER, SANDBOX_ADDENDUM
from .desktop import BODY as DESKTOP_BODY
def build_system_prompt(
mode: FilesystemMode, *, sandbox_available: bool
) -> str:
def build_system_prompt(mode: FilesystemMode, *, sandbox_available: bool) -> str:
"""Assemble the FS prompt: common header + mode body + optional sandbox section."""
body = CLOUD_BODY if mode == FilesystemMode.CLOUD else DESKTOP_BODY
base = HEADER + body

View file

@ -21,7 +21,7 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_cd_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_cd_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
async def async_cd(

View file

@ -24,7 +24,7 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_edit_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_edit_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
async def async_edit_file(

View file

@ -36,7 +36,7 @@ def wrap_as_python(code: str) -> str:
async def execute_in_sandbox(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
command: str,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
timeout: int | None,
@ -59,14 +59,12 @@ async def execute_in_sandbox(
try:
return await _try_sandbox_execute(mw, command, runtime, timeout)
except Exception:
logger.exception(
"Sandbox retry also failed for thread %s", mw._thread_id
)
logger.exception("Sandbox retry also failed for thread %s", mw._thread_id)
return "Error: Code execution is temporarily unavailable. Please try again."
async def _try_sandbox_execute(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
command: str,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
timeout: int | None,

View file

@ -17,13 +17,11 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_execute_code_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_execute_code_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
def sync_execute_code(
command: Annotated[
str, "Python code to execute. Use print() to see output."
],
command: Annotated[str, "Python code to execute. Use print() to see output."],
runtime: ToolRuntime[None, SurfSenseFilesystemState],
timeout: Annotated[
int | None,
@ -35,14 +33,10 @@ def create_execute_code_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
return f"Error: timeout must be non-negative, got {timeout}."
if timeout > MAX_EXECUTE_TIMEOUT:
return f"Error: timeout {timeout}s exceeds maximum ({MAX_EXECUTE_TIMEOUT}s)."
return run_async_blocking(
execute_in_sandbox(mw, command, runtime, timeout)
)
return run_async_blocking(execute_in_sandbox(mw, command, runtime, timeout))
async def async_execute_code(
command: Annotated[
str, "Python code to execute. Use print() to see output."
],
command: Annotated[str, "Python code to execute. Use print() to see output."],
runtime: ToolRuntime[None, SurfSenseFilesystemState],
timeout: Annotated[
int | None,

View file

@ -20,7 +20,7 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_list_tree_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_list_tree_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
async def async_list_tree(

View file

@ -19,7 +19,7 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_ls_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_ls_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
async def async_ls(

View file

@ -23,7 +23,7 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_mkdir_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_mkdir_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
async def async_mkdir(

View file

@ -18,7 +18,7 @@ if TYPE_CHECKING:
async def cloud_move_file(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
source: str,
dest: str,
@ -39,8 +39,7 @@ async def cloud_move_file(
)
if not source.startswith(DOCUMENTS_ROOT + "/"):
return (
"Error: cloud move_file source must be under /documents/ (got "
f"'{source}')."
f"Error: cloud move_file source must be under /documents/ (got '{source}')."
)
if not dest.startswith(DOCUMENTS_ROOT + "/"):
return (
@ -89,9 +88,7 @@ async def cloud_move_file(
],
"messages": [
ToolMessage(
content=(
f"Moved '{source}' to '{dest}' (will commit at end of turn)."
),
content=(f"Moved '{source}' to '{dest}' (will commit at end of turn)."),
tool_call_id=runtime.tool_call_id,
)
],

View file

@ -23,7 +23,7 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_move_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_move_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
async def async_move_file(
@ -85,9 +85,7 @@ def create_move_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
] = False,
) -> Command | str:
return run_async_blocking(
async_move_file(
source_path, destination_path, runtime, overwrite=overwrite
)
async_move_file(source_path, destination_path, runtime, overwrite=overwrite)
)
return StructuredTool.from_function(

View file

@ -16,7 +16,7 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_pwd_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_pwd_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
def sync_pwd(

View file

@ -21,7 +21,7 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_read_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_read_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
async def async_read_file(
@ -90,9 +90,7 @@ def create_read_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
"Maximum number of lines to read.",
] = 100,
) -> Command | str:
return run_async_blocking(
async_read_file(file_path, runtime, offset, limit)
)
return run_async_blocking(async_read_file(file_path, runtime, offset, limit))
return StructuredTool.from_function(
name="read_file",

View file

@ -22,7 +22,7 @@ if TYPE_CHECKING:
async def cloud_rm(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
validated: str,
) -> Command | str:
@ -31,8 +31,7 @@ async def cloud_rm(
return f"Error: refusing to rm '{validated}'."
if not validated.startswith(DOCUMENTS_ROOT + "/"):
return (
"Error: cloud rm must target a path under /documents/ "
f"(got '{validated}')."
f"Error: cloud rm must target a path under /documents/ (got '{validated}')."
)
anon = runtime.state.get("kb_anon_doc") or {}
@ -41,14 +40,10 @@ async def cloud_rm(
staged_dirs = list(runtime.state.get("staged_dirs") or [])
if validated in staged_dirs:
return (
f"Error: '{validated}' is a directory. Use rmdir for "
"empty directories."
)
return f"Error: '{validated}' is a directory. Use rmdir for empty directories."
pending_dir_deletes = list(runtime.state.get("pending_dir_deletes") or [])
if any(
isinstance(d, dict) and d.get("path") == validated
for d in pending_dir_deletes
isinstance(d, dict) and d.get("path") == validated for d in pending_dir_deletes
):
return f"Error: '{validated}' is already queued for rmdir."
@ -57,14 +52,11 @@ async def cloud_rm(
children = await backend.als_info(validated)
if children:
return (
f"Error: '{validated}' is a directory. Use rmdir for "
"empty directories."
f"Error: '{validated}' is a directory. Use rmdir for empty directories."
)
pending_deletes = list(runtime.state.get("pending_deletes") or [])
if any(
isinstance(d, dict) and d.get("path") == validated for d in pending_deletes
):
if any(isinstance(d, dict) and d.get("path") == validated for d in pending_deletes):
return f"'{validated}' is already queued for deletion."
files_state = runtime.state.get("files") or {}
@ -93,8 +85,7 @@ async def cloud_rm(
"messages": [
ToolMessage(
content=(
f"Staged delete of '{validated}' (will commit at "
"end of turn)."
f"Staged delete of '{validated}' (will commit at end of turn)."
),
tool_call_id=runtime.tool_call_id,
)
@ -114,7 +105,7 @@ async def cloud_rm(
async def desktop_rm(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
validated: str,
) -> Command | str:

View file

@ -21,7 +21,7 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_rm_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_rm_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
async def async_rm(

View file

@ -26,7 +26,7 @@ if TYPE_CHECKING:
async def cloud_rmdir(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
validated: str,
) -> Command | str:
@ -49,8 +49,7 @@ async def cloud_rmdir(
staged_dirs = list(runtime.state.get("staged_dirs") or [])
pending_dir_deletes = list(runtime.state.get("pending_dir_deletes") or [])
if any(
isinstance(d, dict) and d.get("path") == validated
for d in pending_dir_deletes
isinstance(d, dict) and d.get("path") == validated for d in pending_dir_deletes
):
return f"'{validated}' is already queued for deletion."
@ -61,11 +60,7 @@ async def cloud_rmdir(
if isinstance(backend, KBPostgresBackend):
children = list(await backend.als_info(validated))
if (
isinstance(backend, KBPostgresBackend)
and not children
and not exists_in_staged
):
if isinstance(backend, KBPostgresBackend) and not children and not exists_in_staged:
loaded = await backend._load_file_data(validated)
if loaded is not None:
return f"Error: '{validated}' is a file. Use rm to delete files."
@ -79,9 +74,7 @@ async def cloud_rmdir(
return f"Error: directory '{validated}' not found."
if children:
return (
f"Error: directory '{validated}' is not empty. Remove contents first."
)
return f"Error: directory '{validated}' is not empty. Remove contents first."
if exists_in_staged:
rest = [d for d in staged_dirs if d != validated]
@ -109,8 +102,7 @@ async def cloud_rmdir(
"messages": [
ToolMessage(
content=(
f"Staged rmdir of '{validated}' (will commit "
"at end of turn)."
f"Staged rmdir of '{validated}' (will commit at end of turn)."
),
tool_call_id=runtime.tool_call_id,
)
@ -120,7 +112,7 @@ async def cloud_rmdir(
async def desktop_rmdir(
mw: "SurfSenseFilesystemMiddleware",
mw: SurfSenseFilesystemMiddleware,
runtime: ToolRuntime[None, SurfSenseFilesystemState],
validated: str,
) -> Command | str:

View file

@ -21,7 +21,7 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_rmdir_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_rmdir_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
async def async_rmdir(

View file

@ -23,7 +23,7 @@ if TYPE_CHECKING:
from ...middleware import SurfSenseFilesystemMiddleware
def create_write_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
def create_write_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool:
description = select_description(mw._filesystem_mode)
async def async_write_file(
@ -73,9 +73,7 @@ def create_write_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool:
content: Annotated[str, "Text content to write to the file."],
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> Command | str:
return run_async_blocking(
async_write_file(file_path, content, runtime)
)
return run_async_blocking(async_write_file(file_path, content, runtime))
return StructuredTool.from_function(
name="write_file",

View file

@ -0,0 +1,272 @@
"""Real-graph parallel HITL across both approval kinds — the keystone regression.
Pre-fix bug: the parallel-HITL routing layer (``collect_pending_tool_calls``
+ ``slice_decisions_by_tool_call`` + ``build_lg_resume_map``) only
recognized middleware-gated approvals (LC HITL shape from
``HumanInTheLoopMiddleware``). Self-gated approvals from
``request_approval`` and middleware-gated permission asks from
``PermissionMiddleware`` both used the SurfSense-specific
``{type, action, context}`` shape, so when the orchestrator dispatched
two parallel ``task`` calls one self-gated, one middleware-gated only
one interrupt was visible to the routing layer and resume crashed with
``Decision count mismatch``.
This test fans out two real subagents via ``Send``: one calls
``request_approval`` (self-gated), the other calls
``request_permission_decision`` (middleware-gated). Both pause; the routing
layer must see TWO LC HITL interrupts, slice the decisions by
``tool_call_id``, key by ``Interrupt.id``, and resume both branches with
their per-slice payload.
"""
from __future__ import annotations
import json
from typing import Annotated
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Command, Send
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
build_lg_resume_map,
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import (
request_permission_decision,
)
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
request_approval,
)
from app.agents.new_chat.permissions import Rule
class _SubState(TypedDict, total=False):
messages: list
class _DispatchState(TypedDict, total=False):
# ``add_messages`` is mandatory: parallel ``Send`` branches both append
# to ``messages`` in the same superstep; without a reducer langgraph
# raises ``InvalidUpdateError``.
messages: Annotated[list, add_messages]
tcid: str
desc: str
subtype: str
def _build_self_gated_subagent(checkpointer: InMemorySaver):
"""Subagent that pauses via ``request_approval`` (self-gated path)."""
def gate_node(_state):
result = request_approval(
action_type="gmail_email_send",
tool_name="send_gmail_email",
params={"to": "alice@example.com"},
)
return {
"messages": [
AIMessage(
content=json.dumps(
{
"kind": "self_gated",
"decision_type": result.decision_type,
"params": result.params,
"rejected": result.rejected,
},
sort_keys=True,
)
)
]
}
g = StateGraph(_SubState)
g.add_node("gate", gate_node)
g.add_edge(START, "gate")
g.add_edge("gate", END)
return g.compile(checkpointer=checkpointer)
def _build_middleware_gated_subagent(checkpointer: InMemorySaver):
"""Subagent that pauses via ``request_permission_decision`` (middleware-gated path)."""
def perm_node(_state):
decision = request_permission_decision(
tool_name="rm",
args={"path": "/tmp/file"},
patterns=["rm/*"],
rules=[Rule(permission="rm", pattern="*", action="ask")],
emit_interrupt=True,
)
return {
"messages": [
AIMessage(
content=json.dumps(
{"kind": "middleware_gated", "decision": decision},
sort_keys=True,
)
)
]
}
g = StateGraph(_SubState)
g.add_node("perm", perm_node)
g.add_edge(START, "perm")
g.add_edge("perm", END)
return g.compile(checkpointer=checkpointer)
def _build_mixed_task_tool(checkpointer: InMemorySaver):
"""Two subagents, one per approval kind, registered under distinct names."""
return build_task_tool_with_parent_config(
[
{
"name": "self-gated-agent",
"description": "uses request_approval",
"runnable": _build_self_gated_subagent(checkpointer),
},
{
"name": "middleware-gated-agent",
"description": "uses request_permission_decision",
"runnable": _build_middleware_gated_subagent(checkpointer),
},
]
)
def _parent_dispatching_one_of_each(
task_tool, *, tcid_self: str, tcid_mw: str, checkpointer
):
def fanout_edge(_state) -> list[Send]:
return [
Send(
"call_task",
{"tcid": tcid_self, "desc": "approve email", "subtype": "self-gated-agent"},
),
Send(
"call_task",
{
"tcid": tcid_mw,
"desc": "approve rm",
"subtype": "middleware-gated-agent",
},
),
]
async def call_task(state: _DispatchState, config: RunnableConfig):
rt = ToolRuntime(
state=state,
config=config,
context=None,
stream_writer=None,
tool_call_id=state["tcid"],
store=None,
)
return await task_tool.coroutine(
description=state["desc"], subagent_type=state["subtype"], runtime=rt
)
g = StateGraph(_DispatchState)
g.add_node("call_task", call_task)
g.add_conditional_edges(START, fanout_edge, ["call_task"])
g.add_edge("call_task", END)
return g.compile(checkpointer=checkpointer)
@pytest.mark.asyncio
async def test_parallel_self_gated_and_middleware_gated_route_and_resume_cleanly():
"""Both interrupt kinds must reach the routing layer in LC HITL shape and resume independently."""
checkpointer = InMemorySaver()
task_tool = _build_mixed_task_tool(checkpointer)
tcid_self = "tcid-self-gated"
tcid_mw = "tcid-middleware-gated"
parent = _parent_dispatching_one_of_each(
task_tool,
tcid_self=tcid_self,
tcid_mw=tcid_mw,
checkpointer=checkpointer,
)
config: dict = {
"configurable": {"thread_id": "mixed-parallel"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
paused = await parent.aget_state(config)
assert len(paused.interrupts) == 2, (
"fixture broken: expected one paused interrupt per approval kind"
)
# Both interrupts must speak the same wire shape — the whole point of
# the unification. If either one regresses to the legacy SurfSense shape
# ``collect_pending_tool_calls`` would silently skip it and the count
# below would be 1.
pending = collect_pending_tool_calls(paused)
assert dict(pending) == {tcid_self: 1, tcid_mw: 1}, (
f"REGRESSION: not all interrupt kinds reached the routing layer; "
f"got {pending!r}"
)
# Verify the actual wire payloads carry the LC HITL standard fields
# (extra defensive assertion against partial regressions where one
# path stamps tool_call_id but reverts the body shape).
interrupt_types = {i.value.get("interrupt_type") for i in paused.interrupts}
assert interrupt_types == {"gmail_email_send", "permission_ask"}
# Resume order: same order the SSE stream would emit (interrupts list).
decision_self = {"type": "approve"}
decision_mw = {"type": "always"}
flat_decisions = [
# Match `pending` order.
decision_self if pending[0][0] == tcid_self else decision_mw,
decision_mw if pending[0][0] == tcid_self else decision_self,
]
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
lg_resume_map = build_lg_resume_map(paused, by_tool_call_id)
assert len(lg_resume_map) == 2
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
await parent.ainvoke(Command(resume=lg_resume_map), config)
final = await parent.aget_state(config)
assert not final.interrupts, (
f"expected both branches resumed, but state still has interrupts: "
f"{final.interrupts!r}"
)
# Each subagent must have received its own slice — verify by inspecting
# the JSON-serialized result messages.
payloads: list[dict] = []
for msg in final.values.get("messages", []) or []:
content = getattr(msg, "content", None)
if isinstance(content, str):
try:
payloads.append(json.loads(content))
except json.JSONDecodeError:
pass
self_payloads = [p for p in payloads if p.get("kind") == "self_gated"]
mw_payloads = [p for p in payloads if p.get("kind") == "middleware_gated"]
assert len(self_payloads) == 1, (
f"self-gated subagent did not complete; payloads: {payloads!r}"
)
assert len(mw_payloads) == 1, (
f"middleware-gated subagent did not complete; payloads: {payloads!r}"
)
# Self-gated approve → HITLResult(decision_type="approve", rejected=False).
assert self_payloads[0]["decision_type"] == "approve"
assert self_payloads[0]["rejected"] is False
# Middleware-gated always → canonical permission shape with always.
assert mw_payloads[0]["decision"] == {"decision_type": "always"}

View file

@ -0,0 +1,125 @@
"""Regression: ``request_permission_decision`` must emit the unified LC HITL wire shape.
Same bug class as :mod:`test_lc_hitl_wire` for self-gated approvals: the
permission middleware previously fired the SurfSense-specific
``{type, action, context}`` shape, which the parallel-HITL routing layer
does not recognize. Standardizing on LC HITL keeps every approval kind on
one routing path.
"""
from __future__ import annotations
import pytest
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import (
request_permission_decision,
)
from app.agents.new_chat.permissions import Rule
class _State(TypedDict, total=False):
messages: list
final_decision: dict
def _build_graph_calling_request_permission_decision(checkpointer: InMemorySaver):
"""Real graph whose only node delegates to the permission ask primitive."""
def perm_node(_state):
decision = request_permission_decision(
tool_name="rm",
args={"path": "/tmp/file"},
patterns=["rm/*"],
rules=[Rule(permission="rm", pattern="*", action="ask")],
emit_interrupt=True,
)
return {"final_decision": decision}
g = StateGraph(_State)
g.add_node("perm", perm_node)
g.add_edge(START, "perm")
g.add_edge("perm", END)
return g.compile(checkpointer=checkpointer)
@pytest.mark.asyncio
async def test_permission_ask_payload_uses_lc_hitl_shape():
"""The permission middleware now speaks the langchain HITL standard shape."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_permission_decision(checkpointer)
config = {"configurable": {"thread_id": "perm-wire"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
snap = await graph.aget_state(config)
assert len(snap.interrupts) == 1
value = snap.interrupts[0].value
assert value.get("action_requests") == [
{"name": "rm", "args": {"path": "/tmp/file"}}
], f"REGRESSION: permission ask reverted to legacy shape; got {value!r}"
review = value.get("review_configs")
assert isinstance(review, list) and len(review) == 1
# ``always`` must be in the palette so the FE can render the promote button.
assert "always" in review[0]["allowed_decisions"]
assert value.get("interrupt_type") == "permission_ask"
# SurfSense context rides through verbatim for FE explainability.
assert value["context"]["patterns"] == ["rm/*"]
assert value["context"]["always"] == ["rm/*"]
@pytest.mark.asyncio
async def test_resume_with_approve_envelope_returns_once_decision():
"""``approve`` from the LC envelope projects to permission-domain ``once``."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_permission_decision(checkpointer)
config = {"configurable": {"thread_id": "perm-once"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(
Command(resume={"decisions": [{"type": "approve"}]}), config
)
final = await graph.aget_state(config)
assert final.values.get("final_decision") == {"decision_type": "once"}
@pytest.mark.asyncio
async def test_resume_with_always_envelope_projects_to_always():
"""``always`` reply must project unchanged so the middleware can promote the rule."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_permission_decision(checkpointer)
config = {"configurable": {"thread_id": "perm-always"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(
Command(resume={"decisions": [{"type": "always"}]}), config
)
final = await graph.aget_state(config)
assert final.values.get("final_decision") == {"decision_type": "always"}
@pytest.mark.asyncio
async def test_resume_with_reject_and_feedback_carries_feedback_through():
"""Reject feedback must survive normalization for ``CorrectedError`` to fire downstream."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_permission_decision(checkpointer)
config = {"configurable": {"thread_id": "perm-reject"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(
Command(
resume={
"decisions": [{"type": "reject", "feedback": "use the trash bin"}]
}
),
config,
)
final = await graph.aget_state(config)
assert final.values.get("final_decision") == {
"decision_type": "reject",
"feedback": "use the trash bin",
}

View file

@ -0,0 +1,169 @@
"""Regression: subagent-owned rulesets layer cleanly into ``PermissionMiddleware``.
The KB unification swap (legacy ``interrupt_on`` map KB-owned ``Ruleset``
threaded through ``build_permission_mw(extra_rulesets=...)``) must produce
*exactly one* interrupt per destructive FS call, in LC HITL shape, even
when ``enable_permission`` is False destructive ops always ask.
We exercise the production factory and a real ``PermissionMiddleware`` on a
real ``StateGraph`` so the test catches regressions in factory gating,
ruleset layering, and interrupt emission together.
"""
from __future__ import annotations
from typing import Annotated, Any
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Command
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.shared.permissions import (
build_permission_mw,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.permissions import Rule, Ruleset
def _kb_style_ruleset() -> Ruleset:
"""Mirror :data:`knowledge_base.agent.KB_RULESET` without importing it.
Importing the agent module pulls in deepagents and prompts; this test
is about the factory + middleware contract, not KB wiring.
"""
return Ruleset(
origin="knowledge_base",
rules=[
Rule(permission="rm", pattern="*", action="ask"),
Rule(permission="rmdir", pattern="*", action="ask"),
Rule(permission="move_file", pattern="*", action="ask"),
Rule(permission="edit_file", pattern="*", action="ask"),
Rule(permission="write_file", pattern="*", action="ask"),
],
)
class _State(TypedDict, total=False):
messages: Annotated[list, add_messages]
def _build_graph_with_permission_middleware(
*,
flags: AgentFeatureFlags,
extra_rulesets: list[Ruleset] | None,
checkpointer: InMemorySaver,
):
"""Compile a one-node graph that emits a tool call for ``rm`` and
routes through the production ``PermissionMiddleware``.
The node returns an ``AIMessage`` with a tool call. The middleware's
``after_model`` hook intercepts and (if a rule says ``ask``) raises
a ``GraphInterrupt`` carrying the LC HITL payload.
"""
pm = build_permission_mw(flags=flags, extra_rulesets=extra_rulesets)
def node(_state: _State) -> dict[str, Any]:
msg = AIMessage(
content="",
tool_calls=[
{
"name": "rm",
"args": {"path": "/tmp/foo"},
"id": "call-rm-1",
"type": "tool_call",
}
],
)
return {"messages": [msg]}
def after_node(state: _State) -> dict[str, Any] | None:
if pm is None:
return None
# PermissionMiddleware._process ignores runtime; the test never relies
# on the runtime context, so passing None keeps the harness lean.
return pm._process(state, None) # type: ignore[arg-type]
g = StateGraph(_State)
g.add_node("emit", node)
g.add_node("permission", after_node)
g.add_edge(START, "emit")
g.add_edge("emit", "permission")
g.add_edge("permission", END)
return g.compile(checkpointer=checkpointer), pm
@pytest.mark.asyncio
async def test_kb_ruleset_raises_one_lc_hitl_ask_for_rm_even_when_permission_flag_off():
"""KB ruleset: ``rm`` must ask once even with ``enable_permission=False``.
This is the keystone of the unification: the legacy ``interrupt_on``
map fired regardless of ``enable_permission``, so the migrated rules
must too. Otherwise users could opt out of "ask before rm".
"""
flags = AgentFeatureFlags(enable_permission=False)
checkpointer = InMemorySaver()
graph, pm = _build_graph_with_permission_middleware(
flags=flags,
extra_rulesets=[_kb_style_ruleset()],
checkpointer=checkpointer,
)
assert pm is not None, "extras must force the middleware on"
config = {"configurable": {"thread_id": "kb-cloud-rm"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
snap = await graph.aget_state(config)
assert len(snap.interrupts) == 1, (
f"REGRESSION: KB ruleset should raise exactly one interrupt; got "
f"{[i.value for i in snap.interrupts]!r}"
)
payload = snap.interrupts[0].value
requests = payload.get("action_requests")
assert requests == [{"name": "rm", "args": {"path": "/tmp/foo"}}], (
f"interrupt must carry the rm call in LC HITL shape; got {payload!r}"
)
assert payload.get("interrupt_type") == "permission_ask"
@pytest.mark.asyncio
async def test_kb_ruleset_resume_with_approve_lets_rm_through():
"""Resume with ``approve`` → call kept; the model continues normally."""
flags = AgentFeatureFlags(enable_permission=False)
checkpointer = InMemorySaver()
graph, _ = _build_graph_with_permission_middleware(
flags=flags,
extra_rulesets=[_kb_style_ruleset()],
checkpointer=checkpointer,
)
config = {"configurable": {"thread_id": "kb-cloud-rm-approve"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(
Command(resume={"decisions": [{"type": "approve"}]}), config
)
final = await graph.aget_state(config)
assert final.next == (), "graph must complete after approve"
last_ai = next(
(m for m in reversed(final.values["messages"]) if isinstance(m, AIMessage)),
None,
)
assert last_ai is not None
assert [tc["name"] for tc in last_ai.tool_calls] == ["rm"], (
"approved rm call must remain on the AIMessage so the tool can run"
)
@pytest.mark.asyncio
async def test_no_extras_with_permission_off_skips_middleware_entirely():
"""No extras + permission off → factory returns ``None`` (no engine).
The legacy gating is preserved when no caller asks for rules: nothing
runs, nothing pauses.
"""
flags = AgentFeatureFlags(enable_permission=False)
pm = build_permission_mw(flags=flags, extra_rulesets=None)
assert pm is None

View file

@ -0,0 +1,132 @@
"""Regression: ``request_approval`` must emit the unified LC HITL wire shape.
Before this fix, self-gated approvals fired the SurfSense-specific
``{type, action, context}`` shape which the parallel-HITL routing layer
(``collect_pending_tool_calls``) does not recognize. In a parallel HITL
scenario where one subagent used self-gated approvals (e.g. Gmail send)
and another used middleware-gated approvals (e.g. Linear via
``HumanInTheLoopMiddleware``), the routing layer would silently skip the
self-gated interrupt and crash on resume with ``Decision count mismatch``.
This test pins the wire contract by running ``request_approval`` inside a
real ``StateGraph`` and asserting the paused parent observes the LC HITL
shape (``action_requests``, ``review_configs``, ``interrupt_type``).
"""
from __future__ import annotations
import pytest
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
request_approval,
)
class _State(TypedDict, total=False):
messages: list
final_decision_type: str
final_params: dict
def _build_graph_calling_request_approval(checkpointer: InMemorySaver):
"""A real graph whose only node delegates to ``request_approval``."""
def gate_node(_state):
result = request_approval(
action_type="gmail_email_send",
tool_name="send_gmail_email",
params={"to": "alice@example.com", "subject": "hi"},
context={"account": "alice@gmail.com"},
)
return {
"final_decision_type": result.decision_type,
"final_params": result.params,
}
g = StateGraph(_State)
g.add_node("gate", gate_node)
g.add_edge(START, "gate")
g.add_edge("gate", END)
return g.compile(checkpointer=checkpointer)
@pytest.mark.asyncio
async def test_paused_interrupt_uses_lc_hitl_action_requests_shape():
"""The paused interrupt must speak the langchain HITL standard shape."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_approval(checkpointer)
config = {"configurable": {"thread_id": "self-gated-wire"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
snap = await graph.aget_state(config)
assert len(snap.interrupts) == 1, (
f"expected one paused interrupt, got {len(snap.interrupts)}"
)
value = snap.interrupts[0].value
assert isinstance(value, dict)
# Standard LC HITL fields the routing layer reads.
assert value.get("action_requests") == [
{
"name": "send_gmail_email",
"args": {"to": "alice@example.com", "subject": "hi"},
}
], (
"REGRESSION: self-gated approval reverted to legacy SurfSense shape; "
f"got {value!r}"
)
assert value.get("review_configs") == [
{
"action_name": "send_gmail_email",
"allowed_decisions": ["approve", "reject", "edit"],
}
]
assert value.get("interrupt_type") == "gmail_email_send", (
"FE card discriminator must travel as ``interrupt_type``."
)
assert value.get("context") == {"account": "alice@gmail.com"}
@pytest.mark.asyncio
async def test_resume_with_lc_envelope_returns_hitl_result_with_edited_args():
"""Edit reply via the LC envelope must round-trip into ``HITLResult.params``."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_approval(checkpointer)
config = {"configurable": {"thread_id": "self-gated-resume"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
edited = {"to": "alice@example.com", "subject": "EDITED"}
await graph.ainvoke(
Command(
resume={
"decisions": [
{"type": "edit", "edited_action": {"args": {"subject": "EDITED"}}}
]
}
),
config,
)
final = await graph.aget_state(config)
assert final.values.get("final_decision_type") == "edit"
assert final.values.get("final_params") == edited
@pytest.mark.asyncio
async def test_reject_envelope_returns_rejected_hitl_result():
"""Reject reply must surface as ``HITLResult.rejected=True`` without invoking the tool."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_approval(checkpointer)
config = {"configurable": {"thread_id": "self-gated-reject"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(
Command(resume={"decisions": [{"type": "reject", "feedback": "no"}]}),
config,
)
final = await graph.aget_state(config)
assert final.values.get("final_decision_type") == "reject"

View file

@ -0,0 +1,168 @@
"""Unit contract for the unified LC HITL wire format.
Both the self-gated approval primitive (``request_approval``) and the
middleware-gated permission ask (``PermissionMiddleware``) must serialize
to the same wire shape so the parallel-HITL routing layer
(``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call`` +
``build_lg_resume_map``) sees one format.
These tests pin the shape:
- Builder always emits ``action_requests`` (1 entry) + ``review_configs``
+ ``interrupt_type``; ``context`` rides through verbatim when present.
- Parser tolerates the standard LC envelope, bare scalar strings, and
unrecognized shapes (failing closed to ``reject``).
- Edited args round-trip through both nested (``edited_action.args``) and
flat (``args``) shapes without inventing values for the empty case.
"""
from __future__ import annotations
from app.agents.multi_agent_chat.subagents.shared.hitl.wire import (
LC_DECISION_APPROVE,
LC_DECISION_EDIT,
LC_DECISION_REJECT,
SURFSENSE_DECISION_ALWAYS,
build_lc_hitl_payload,
parse_lc_envelope,
)
class TestBuildLcHitlPayload:
def test_minimal_payload_has_one_action_request_and_one_review_config(self):
payload = build_lc_hitl_payload(
tool_name="send_email",
args={"to": "x@y.z"},
allowed_decisions=[LC_DECISION_APPROVE, LC_DECISION_REJECT],
interrupt_type="gmail_email_send",
)
assert payload["action_requests"] == [
{"name": "send_email", "args": {"to": "x@y.z"}}
]
assert payload["review_configs"] == [
{
"action_name": "send_email",
"allowed_decisions": [LC_DECISION_APPROVE, LC_DECISION_REJECT],
}
]
assert payload["interrupt_type"] == "gmail_email_send"
assert "context" not in payload, "context must be omitted when not provided"
def test_none_args_normalized_to_empty_dict(self):
"""FE expects a stable shape; ``None`` would crash card rendering."""
payload = build_lc_hitl_payload(
tool_name="ping",
args=None, # type: ignore[arg-type]
allowed_decisions=[LC_DECISION_APPROVE],
interrupt_type="self_gated",
)
assert payload["action_requests"][0]["args"] == {}
def test_description_attached_only_when_provided(self):
with_desc = build_lc_hitl_payload(
tool_name="t",
args={},
allowed_decisions=[LC_DECISION_APPROVE],
interrupt_type="x",
description="please review",
)
without = build_lc_hitl_payload(
tool_name="t",
args={},
allowed_decisions=[LC_DECISION_APPROVE],
interrupt_type="x",
)
assert with_desc["action_requests"][0]["description"] == "please review"
assert "description" not in without["action_requests"][0]
def test_context_passed_through_verbatim(self):
ctx = {"patterns": ["rm/*"], "rules": [], "always": ["rm/*"]}
payload = build_lc_hitl_payload(
tool_name="rm",
args={"path": "/tmp"},
allowed_decisions=[
LC_DECISION_APPROVE,
LC_DECISION_REJECT,
SURFSENSE_DECISION_ALWAYS,
],
interrupt_type="permission_ask",
context=ctx,
)
assert payload["context"] == ctx
def test_allowed_decisions_list_is_copied_not_aliased(self):
"""A caller mutating their original list must not corrupt the payload."""
decisions = [LC_DECISION_APPROVE]
payload = build_lc_hitl_payload(
tool_name="t",
args={},
allowed_decisions=decisions,
interrupt_type="x",
)
decisions.append(LC_DECISION_REJECT)
assert payload["review_configs"][0]["allowed_decisions"] == [LC_DECISION_APPROVE]
class TestParseLcEnvelope:
def test_standard_lc_envelope_returns_typed_decision(self):
parsed = parse_lc_envelope({"decisions": [{"type": "approve"}]})
assert parsed.decision_type == "approve"
assert parsed.edited_args is None
assert parsed.message is None
def test_bare_scalar_string_passes_through_lowercased(self):
assert parse_lc_envelope("ALWAYS").decision_type == "always"
assert parse_lc_envelope("once").decision_type == "once"
def test_non_dict_non_string_collapses_to_reject(self):
"""Failing closed: ambiguous input must never proceed."""
assert parse_lc_envelope(42).decision_type == "reject"
assert parse_lc_envelope(None).decision_type == "reject"
assert parse_lc_envelope(["bogus"]).decision_type == "reject"
def test_missing_decision_type_collapses_to_reject(self):
assert parse_lc_envelope({"decisions": [{}]}).decision_type == "reject"
assert parse_lc_envelope({"foo": "bar"}).decision_type == "reject"
def test_edit_extracts_nested_args(self):
parsed = parse_lc_envelope(
{
"decisions": [
{
"type": LC_DECISION_EDIT,
"edited_action": {"args": {"to": "edited@y.z"}},
}
]
}
)
assert parsed.decision_type == "edit"
assert parsed.edited_args == {"to": "edited@y.z"}
def test_edit_falls_back_to_flat_args(self):
parsed = parse_lc_envelope(
{"decisions": [{"type": "edit", "args": {"k": "v"}}]}
)
assert parsed.edited_args == {"k": "v"}
def test_edit_with_empty_args_yields_none_edited(self):
"""Empty edited_args means "no edits" — caller treats as plain approve."""
parsed = parse_lc_envelope(
{"decisions": [{"type": "edit", "edited_action": {"args": {}}}]}
)
assert parsed.edited_args is None
def test_message_picked_from_either_feedback_or_message_field(self):
with_feedback = parse_lc_envelope(
{"decisions": [{"type": "reject", "feedback": "no thanks"}]}
)
with_message = parse_lc_envelope(
{"decisions": [{"type": "reject", "message": "no thanks"}]}
)
assert with_feedback.message == "no thanks"
assert with_message.message == "no thanks"
def test_blank_message_treated_as_absent(self):
parsed = parse_lc_envelope(
{"decisions": [{"type": "reject", "message": " "}]}
)
assert parsed.message is None