Merge branch 'dev' into feat/e2e-testing

This commit is contained in:
Rohan Verma 2026-05-09 16:10:45 -07:00 committed by GitHub
commit fa31da9937
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
100 changed files with 3751 additions and 1122 deletions

View file

@@ -0,0 +1,208 @@
"""End-to-end resume-bridge tests against a real LangGraph subagent."""
from __future__ import annotations
import ast
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
class _SubagentState(TypedDict, total=False):
    # Minimal state schema shared by the test subagents. ``total=False``
    # lets each node return a partial update instead of the whole state.
    messages: list  # conversation transcript accumulated across nodes
    decision_text: str  # repr() of whatever resume value interrupt() returned
def _build_single_interrupt_subagent():
    """Compile a one-node graph that pauses on a single HITL interrupt.

    The node surfaces an approval-style payload via ``interrupt`` and, once
    resumed, records the resume value's ``repr`` under ``decision_text``.
    """

    def approve_node(state):
        from langchain_core.messages import AIMessage

        payload = {
            "action_requests": [
                {
                    "name": "do_thing",
                    "args": {"x": 1},
                    "description": "test action",
                }
            ],
            "review_configs": [{}],
        }
        resume_value = interrupt(payload)
        return {
            "messages": [AIMessage(content="done")],
            "decision_text": repr(resume_value),
        }

    builder = StateGraph(_SubagentState)
    builder.add_node("approve", approve_node)
    builder.add_edge(START, "approve")
    builder.add_edge("approve", END)
    return builder.compile(checkpointer=InMemorySaver())
def _make_runtime(config: dict) -> ToolRuntime:
    """Wrap *config* in a minimal ``ToolRuntime`` for invoking the task tool."""
    seed_state = {"messages": [HumanMessage(content="seed")]}
    return ToolRuntime(
        state=seed_state,
        context=None,
        config=config,
        stream_writer=None,
        tool_call_id="parent-tcid-1",
        store=None,
    )
@pytest.mark.asyncio
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
    """Side-channel decision must reach the subagent's pending interrupt verbatim."""
    agent = _build_single_interrupt_subagent()
    spec = {
        "name": "approver",
        "description": "approves things",
        "runnable": agent,
    }
    task_tool = build_task_tool_with_parent_config([spec])
    parent_config: dict = {
        "configurable": {"thread_id": "shared-thread"},
        "recursion_limit": 100,
    }

    # First invocation parks the subagent on its interrupt.
    await agent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
    paused = await agent.aget_state(parent_config)
    assert paused.tasks and paused.tasks[0].interrupts, (
        "fixture broken: subagent should be paused on its interrupt"
    )

    # Queue the side-channel decision, then re-enter through the task tool.
    parent_config["configurable"]["surfsense_resume_value"] = {
        "decisions": ["APPROVED"]
    }
    runtime = _make_runtime(parent_config)
    result = await task_tool.coroutine(
        description="please approve",
        subagent_type="approver",
        runtime=runtime,
    )

    assert isinstance(result, Command)
    assert result.update["decision_text"] == repr({"decisions": ["APPROVED"]})
    # The side channel must be consumed exactly once.
    assert "surfsense_resume_value" not in parent_config["configurable"]
    resumed = await agent.aget_state(parent_config)
    assert not resumed.tasks or all(not t.interrupts for t in resumed.tasks)
@pytest.mark.asyncio
async def test_pending_interrupt_without_resume_value_raises_runtime_error():
    """Bridge must fail loud rather than silently replay the user's interrupt."""
    agent = _build_single_interrupt_subagent()
    spec = {
        "name": "approver",
        "description": "approves things",
        "runnable": agent,
    }
    task_tool = build_task_tool_with_parent_config([spec])
    parent_config: dict = {
        "configurable": {"thread_id": "guard-thread"},
        "recursion_limit": 100,
    }

    # Park the subagent on its interrupt without queuing any decision.
    await agent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
    paused = await agent.aget_state(parent_config)
    assert paused.tasks and paused.tasks[0].interrupts, "fixture broken"

    # Re-entering with nothing queued must surface a loud failure.
    with pytest.raises(RuntimeError, match="resume bridge is broken"):
        await task_tool.coroutine(
            description="please approve",
            subagent_type="approver",
            runtime=_make_runtime(parent_config),
        )
def _build_bundle_subagent():
    """Compile a one-node graph that interrupts with a three-action bundle."""

    def bundle_node(state):
        from langchain_core.messages import AIMessage

        actions = [
            {"name": "create_a", "args": {}, "description": ""},
            {"name": "create_b", "args": {}, "description": ""},
            {"name": "create_c", "args": {}, "description": ""},
        ]
        resume_value = interrupt(
            {"action_requests": actions, "review_configs": [{}, {}, {}]}
        )
        return {
            "messages": [AIMessage(content="bundle-done")],
            "decision_text": repr(resume_value),
        }

    builder = StateGraph(_SubagentState)
    builder.add_node("bundle", bundle_node)
    builder.add_edge(START, "bundle")
    builder.add_edge("bundle", END)
    return builder.compile(checkpointer=InMemorySaver())
@pytest.mark.asyncio
async def test_bundle_three_mixed_decisions_arrive_in_order():
    """Approve / edit / reject for a 3-action bundle must land at ordinals 0/1/2."""
    agent = _build_bundle_subagent()
    spec = {
        "name": "bundler",
        "description": "creates a bundle",
        "runnable": agent,
    }
    task_tool = build_task_tool_with_parent_config([spec])
    parent_config: dict = {
        "configurable": {"thread_id": "bundle-thread"},
        "recursion_limit": 100,
    }
    await agent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)

    approve = {"type": "approve", "args": {}}
    edit = {"type": "edit", "args": {"args": {"name": "edited-b"}}}
    reject = {"type": "reject", "args": {"message": "no thanks"}}
    decisions_payload = {"decisions": [approve, edit, reject]}
    parent_config["configurable"]["surfsense_resume_value"] = decisions_payload

    result = await task_tool.coroutine(
        description="run bundle",
        subagent_type="bundler",
        runtime=_make_runtime(parent_config),
    )
    assert isinstance(result, Command)

    # ``decision_text`` is repr() of what interrupt() returned inside the node.
    received = ast.literal_eval(result.update["decision_text"])
    assert received == decisions_payload
    ordinals = received["decisions"]
    assert ordinals[0]["type"] == "approve"
    assert ordinals[1]["type"] == "edit"
    assert ordinals[1]["args"] == {"args": {"name": "edited-b"}}
    assert ordinals[2]["type"] == "reject"

View file

@@ -0,0 +1,55 @@
"""Pins the first-wins assumption of ``get_first_pending_subagent_interrupt``.
The bridge currently relies on at-most-one pending interrupt per snapshot
(sequential tool nodes). If parallel tool calls are ever enabled, the bridge
needs an id-aware lookup; these tests will need to be revisited at that point.
"""
from __future__ import annotations
from types import SimpleNamespace
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume import (
get_first_pending_subagent_interrupt,
)
class TestGetFirstPendingSubagentInterrupt:
    """Behavioral pin: the helper returns the FIRST pending interrupt only."""

    def test_returns_first_when_multiple_top_level_interrupts_pending(self):
        winner = SimpleNamespace(id="i-1", value={"decision": "approve"})
        loser = SimpleNamespace(id="i-2", value={"decision": "reject"})
        snapshot = SimpleNamespace(interrupts=(winner, loser), tasks=())
        result = get_first_pending_subagent_interrupt(snapshot)
        assert result == ("i-1", {"decision": "approve"})

    def test_returns_first_when_multiple_subtask_interrupts_pending(self):
        winner = SimpleNamespace(id="i-A", value="approve")
        loser = SimpleNamespace(id="i-B", value="reject")
        task = SimpleNamespace(interrupts=(winner, loser))
        snapshot = SimpleNamespace(interrupts=(), tasks=(task,))
        assert get_first_pending_subagent_interrupt(snapshot) == ("i-A", "approve")

    def test_returns_none_when_no_interrupts(self):
        snapshot = SimpleNamespace(interrupts=(), tasks=())
        assert get_first_pending_subagent_interrupt(snapshot) == (None, None)

    def test_returns_none_when_state_is_none(self):
        assert get_first_pending_subagent_interrupt(None) == (None, None)

    def test_skips_interrupts_with_none_value(self):
        blank = SimpleNamespace(id="i-empty", value=None)
        usable = SimpleNamespace(id="i-real", value="approve")
        snapshot = SimpleNamespace(interrupts=(blank, usable), tasks=())
        assert get_first_pending_subagent_interrupt(snapshot) == ("i-real", "approve")

    def test_normalizes_non_string_id_to_none(self):
        numeric = SimpleNamespace(id=12345, value="approve")
        snapshot = SimpleNamespace(interrupts=(numeric,), tasks=())
        assert get_first_pending_subagent_interrupt(snapshot) == (None, "approve")

View file

@@ -0,0 +1,71 @@
"""Resume side-channel must be read exactly once per turn."""
from __future__ import annotations
from langchain.tools import ToolRuntime
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
consume_surfsense_resume,
has_surfsense_resume,
)
def _runtime_with_config(config: dict) -> ToolRuntime:
    """Build a ``ToolRuntime`` carrying *config*, every other field stubbed."""
    stub_fields = dict(
        state=None,
        context=None,
        stream_writer=None,
        tool_call_id="tcid-test",
        store=None,
    )
    return ToolRuntime(config=config, **stub_fields)
class TestConsumeSurfsenseResume:
    """Pop semantics: the queued payload is returned once, then gone."""

    def test_pops_value_on_first_call(self):
        config = {
            "configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}
        }
        runtime = _runtime_with_config(config)
        assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}

    def test_second_call_returns_none(self):
        configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
        runtime = _runtime_with_config({"configurable": configurable})
        consume_surfsense_resume(runtime)
        # Second read is empty and the key itself has been removed.
        assert consume_surfsense_resume(runtime) is None
        assert "surfsense_resume_value" not in configurable

    def test_returns_none_when_no_payload_queued(self):
        runtime = _runtime_with_config({"configurable": {}})
        assert consume_surfsense_resume(runtime) is None

    def test_returns_none_when_configurable_missing(self):
        runtime = _runtime_with_config({})
        assert consume_surfsense_resume(runtime) is None
class TestHasSurfsenseResume:
    """Peek semantics: the presence check never mutates the config."""

    def test_true_when_payload_queued(self):
        config = {"configurable": {"surfsense_resume_value": "approve"}}
        assert has_surfsense_resume(_runtime_with_config(config)) is True

    def test_does_not_consume_payload(self):
        configurable = {"surfsense_resume_value": "approve"}
        runtime = _runtime_with_config({"configurable": configurable})
        has_surfsense_resume(runtime)
        # Peeking must leave the queued payload in place for the real consumer.
        assert configurable == {"surfsense_resume_value": "approve"}

    def test_false_when_payload_absent(self):
        runtime = _runtime_with_config({"configurable": {}})
        assert has_surfsense_resume(runtime) is False

View file

@@ -0,0 +1,96 @@
"""Subagent resilience contract: ``extra_middleware`` reaches the agent chain."""
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator
from typing import Any
import pytest
from langchain.agents import create_agent
from langchain.agents.middleware import ModelFallbackMiddleware
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
pack_subagent,
)
class RateLimitError(Exception):
    """Name matches the scoped-fallback eligibility allowlist.

    Eligibility is keyed on the exception's class name, so this local
    stand-in exercises the same path a real provider 429 would.
    """
class _AlwaysFailingChatModel(BaseChatModel):
    """Chat-model double whose every entry point raises ``RateLimitError``.

    Sync and async generate/stream all fail through a single helper so the
    test proves the fallback middleware recovers regardless of which
    invocation path the agent takes.
    """

    @property
    def _llm_type(self) -> str:
        return "always-failing-test-model"

    @staticmethod
    def _boom() -> RateLimitError:
        # Single construction site keeps the four failure paths identical
        # (the original duplicated the message/raise pair in each method).
        msg = "primary llm exploded"
        return RateLimitError(msg)

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise self._boom()

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise self._boom()

    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
        raise self._boom()

    async def _astream(
        self, *args: Any, **kwargs: Any
    ) -> AsyncIterator[ChatGeneration]:
        raise self._boom()
        yield  # pragma: no cover - unreachable, satisfies async generator typing
@pytest.mark.asyncio
async def test_subagent_recovers_when_primary_llm_fails():
    """Fallback in ``extra_middleware`` must finish the turn when primary raises."""
    broken_primary = _AlwaysFailingChatModel()
    fallback_model = FakeMessagesListChatModel(
        responses=[AIMessage(content="recovered via fallback")]
    )

    spec = pack_subagent(
        name="resilience_test",
        description="test subagent",
        system_prompt="be helpful",
        tools=[],
        model=broken_primary,
        extra_middleware=[ModelFallbackMiddleware(fallback_model)],
    )
    agent = create_agent(
        model=spec["model"],
        tools=spec["tools"],
        middleware=spec["middleware"],
        system_prompt=spec["system_prompt"],
    )

    result = await agent.ainvoke({"messages": [HumanMessage(content="hi")]})

    last_message = result["messages"][-1]
    assert isinstance(last_message, AIMessage)
    assert last_message.content == "recovered via fallback"

View file

@@ -0,0 +1,128 @@
"""``ScopedModelFallbackMiddleware`` triggers fallback only on provider errors."""
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator
from typing import Any
import pytest
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
class _RaisingChatModel(BaseChatModel):
    """Chat-model double that raises a configurable exception on every path."""

    # Exception instance raised verbatim by every generate/stream entry point.
    exc_to_raise: Any

    @property
    def _llm_type(self) -> str:
        return "raising-test-model"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise self.exc_to_raise

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise self.exc_to_raise

    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
        raise self.exc_to_raise

    async def _astream(
        self, *args: Any, **kwargs: Any
    ) -> AsyncIterator[ChatGeneration]:
        raise self.exc_to_raise
        yield  # pragma: no cover - unreachable
class _RecordingChatModel(BaseChatModel):
    """Fallback double that counts invocations and returns a fixed reply."""

    response_text: str = "fallback-ok"
    call_count: int = 0

    @property
    def _llm_type(self) -> str:
        return "recording-test-model"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        self.call_count += 1
        reply = AIMessage(content=self.response_text)
        generation = ChatGeneration(message=reply)
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Delegate to the sync path so both paths share one counter.
        return self._generate(messages, stop, None, **kwargs)
class RateLimitError(Exception):
    """Name matches the scoped-fallback eligibility allowlist.

    The allowlist matches on the exception's class name, so this local
    class stands in for a real provider rate-limit error.
    """
def _build_agent(primary: BaseChatModel, fallback: BaseChatModel):
    """Wire *primary* behind the scoped fallback middleware targeting *fallback*."""
    from langchain.agents import create_agent

    from app.agents.new_chat.middleware.scoped_model_fallback import (
        ScopedModelFallbackMiddleware,
    )

    middleware_chain = [ScopedModelFallbackMiddleware(fallback)]
    return create_agent(
        model=primary,
        tools=[],
        middleware=middleware_chain,
        system_prompt="be helpful",
    )
@pytest.mark.asyncio
async def test_provider_errors_trigger_fallback():
    """Eligible exception names must drive the fallback chain."""
    failing = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider"))
    recorder = _RecordingChatModel(response_text="recovered")
    agent = _build_agent(failing, recorder)

    result = await agent.ainvoke({"messages": [("user", "hi")]})

    last = result["messages"][-1]
    assert isinstance(last, AIMessage)
    assert last.content == "recovered"
    # Exactly one fallback invocation — no retries, no double dispatch.
    assert recorder.call_count == 1
@pytest.mark.asyncio
async def test_programming_errors_propagate_without_invoking_fallback():
    """Non-eligible exceptions must propagate; fallback must not be invoked."""
    failing = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field"))
    recorder = _RecordingChatModel(response_text="should-never-arrive")
    agent = _build_agent(failing, recorder)

    with pytest.raises(KeyError, match="missing_state_field"):
        await agent.ainvoke({"messages": [("user", "hi")]})

    # A bug-class exception must never be masked by the fallback model.
    assert recorder.call_count == 0

View file

@@ -202,6 +202,15 @@ class FakeBudgetLLM:
class TestKnowledgeBaseSearchMiddlewarePlanner:
    @pytest.fixture(autouse=True)
    def _disable_planner_runnable(self, monkeypatch) -> None:
        # ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the
        # planner Runnable path is enabled) calls ``.bind()`` on the LLM,
        # which the mock does not implement. Pin the flag off so the
        # planner falls through to the legacy ``self.llm.ainvoke`` path
        # these tests assert against (``llm.calls[0]["config"]``).
        monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false")
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
messages = [
HumanMessage(content="old user context " * 40),