mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-12 17:22:38 +02:00
Merge branch 'dev' into feat/e2e-testing
This commit is contained in:
commit
fa31da9937
100 changed files with 3751 additions and 1122 deletions
|
|
@ -0,0 +1,208 @@
|
|||
"""End-to-end resume-bridge tests against a real LangGraph subagent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Command, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
||||
|
||||
class _SubagentState(TypedDict, total=False):
    # Message history accumulated by the subagent's nodes (LangGraph-style state).
    messages: list
    # repr() of whatever ``interrupt`` returned — the observable the tests assert on.
    decision_text: str
def _build_single_interrupt_subagent():
    """Compile a one-node subagent that pauses on a single HITL interrupt.

    The node's return records ``repr(interrupt(...))`` so tests can inspect
    exactly which resume value reached the paused node.
    """

    def approve_node(node_state):
        # Lazy import inside the node, mirroring the module's import style.
        from langchain_core.messages import AIMessage

        request = {
            "action_requests": [
                {
                    "name": "do_thing",
                    "args": {"x": 1},
                    "description": "test action",
                }
            ],
            "review_configs": [{}],
        }
        resume_value = interrupt(request)
        return {
            "messages": [AIMessage(content="done")],
            "decision_text": repr(resume_value),
        }

    builder = StateGraph(_SubagentState)
    builder.add_node("approve", approve_node)
    builder.add_edge(START, "approve")
    builder.add_edge("approve", END)
    return builder.compile(checkpointer=InMemorySaver())
def _make_runtime(config: dict) -> ToolRuntime:
    """Build a minimal ``ToolRuntime`` that carries *config* to the task tool."""
    seed_state = {"messages": [HumanMessage(content="seed")]}
    return ToolRuntime(
        state=seed_state,
        context=None,
        config=config,
        stream_writer=None,
        tool_call_id="parent-tcid-1",
        store=None,
    )
@pytest.mark.asyncio
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
    """Side-channel decision must reach the subagent's pending interrupt verbatim."""
    subagent = _build_single_interrupt_subagent()
    registration = {
        "name": "approver",
        "description": "approves things",
        "runnable": subagent,
    }
    task_tool = build_task_tool_with_parent_config([registration])

    parent_config: dict = {
        "configurable": {"thread_id": "shared-thread"},
        "recursion_limit": 100,
    }
    # First turn: run the subagent until it parks on its interrupt.
    await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
    paused = await subagent.aget_state(parent_config)
    assert paused.tasks and paused.tasks[0].interrupts, (
        "fixture broken: subagent should be paused on its interrupt"
    )

    # Queue the user's decision on the side channel, then re-enter via the tool.
    resume_payload = {"decisions": ["APPROVED"]}
    parent_config["configurable"]["surfsense_resume_value"] = resume_payload
    runtime = _make_runtime(parent_config)

    result = await task_tool.coroutine(
        description="please approve",
        subagent_type="approver",
        runtime=runtime,
    )

    assert isinstance(result, Command)
    assert result.update["decision_text"] == repr({"decisions": ["APPROVED"]})
    # Bridge must consume the side channel exactly once.
    assert "surfsense_resume_value" not in parent_config["configurable"]

    final = await subagent.aget_state(parent_config)
    assert not final.tasks or all(not t.interrupts for t in final.tasks)
@pytest.mark.asyncio
async def test_pending_interrupt_without_resume_value_raises_runtime_error():
    """Bridge must fail loud rather than silently replay the user's interrupt."""
    subagent = _build_single_interrupt_subagent()
    task_tool = build_task_tool_with_parent_config(
        [
            {
                "name": "approver",
                "description": "approves things",
                "runnable": subagent,
            }
        ]
    )

    parent_config: dict = {
        "configurable": {"thread_id": "guard-thread"},
        "recursion_limit": 100,
    }
    await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
    snapshot = await subagent.aget_state(parent_config)
    assert snapshot.tasks and snapshot.tasks[0].interrupts, "fixture broken"

    # No ``surfsense_resume_value`` queued: the bridge must refuse to proceed.
    with pytest.raises(RuntimeError, match="resume bridge is broken"):
        await task_tool.coroutine(
            description="please approve",
            subagent_type="approver",
            runtime=_make_runtime(parent_config),
        )
def _build_bundle_subagent():
    """Compile a one-node subagent whose single interrupt bundles three actions."""

    def bundle_node(node_state):
        from langchain_core.messages import AIMessage

        bundle_request = {
            "action_requests": [
                {"name": "create_a", "args": {}, "description": ""},
                {"name": "create_b", "args": {}, "description": ""},
                {"name": "create_c", "args": {}, "description": ""},
            ],
            "review_configs": [{}, {}, {}],
        }
        decision = interrupt(bundle_request)
        return {
            "messages": [AIMessage(content="bundle-done")],
            "decision_text": repr(decision),
        }

    builder = StateGraph(_SubagentState)
    builder.add_node("bundle", bundle_node)
    builder.add_edge(START, "bundle")
    builder.add_edge("bundle", END)
    return builder.compile(checkpointer=InMemorySaver())
@pytest.mark.asyncio
async def test_bundle_three_mixed_decisions_arrive_in_order():
    """Approve / edit / reject for a 3-action bundle must land at ordinals 0/1/2."""
    subagent = _build_bundle_subagent()
    task_tool = build_task_tool_with_parent_config(
        [
            {
                "name": "bundler",
                "description": "creates a bundle",
                "runnable": subagent,
            }
        ]
    )

    parent_config: dict = {
        "configurable": {"thread_id": "bundle-thread"},
        "recursion_limit": 100,
    }
    await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)

    approve = {"type": "approve", "args": {}}
    edit = {"type": "edit", "args": {"args": {"name": "edited-b"}}}
    reject = {"type": "reject", "args": {"message": "no thanks"}}
    decisions_payload = {"decisions": [approve, edit, reject]}
    parent_config["configurable"]["surfsense_resume_value"] = decisions_payload

    result = await task_tool.coroutine(
        description="run bundle",
        subagent_type="bundler",
        runtime=_make_runtime(parent_config),
    )

    assert isinstance(result, Command)
    # The node stored repr() of the resume value; round-trip it back to a dict.
    received = ast.literal_eval(result.update["decision_text"])
    assert received == decisions_payload
    assert received["decisions"][0]["type"] == "approve"
    assert received["decisions"][1]["type"] == "edit"
    assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
    assert received["decisions"][2]["type"] == "reject"
|
@ -0,0 +1,55 @@
|
|||
"""Pins the first-wins assumption of ``get_first_pending_subagent_interrupt``.
|
||||
|
||||
The bridge currently relies on at-most-one pending interrupt per snapshot
|
||||
(sequential tool nodes). If parallel tool calls are ever enabled, the bridge
|
||||
needs an id-aware lookup; these tests will need to be revisited at that point.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume import (
|
||||
get_first_pending_subagent_interrupt,
|
||||
)
|
||||
|
||||
|
||||
class TestGetFirstPendingSubagentInterrupt:
    """First-wins lookup across top-level interrupts, then sub-task interrupts."""

    @staticmethod
    def _snapshot(interrupts=(), tasks=()):
        # Duck-typed stand-in for a LangGraph state snapshot.
        return SimpleNamespace(interrupts=tuple(interrupts), tasks=tuple(tasks))

    def test_returns_first_when_multiple_top_level_interrupts_pending(self):
        pending = [
            SimpleNamespace(id="i-1", value={"decision": "approve"}),
            SimpleNamespace(id="i-2", value={"decision": "reject"}),
        ]
        state = self._snapshot(interrupts=pending)

        assert get_first_pending_subagent_interrupt(state) == (
            "i-1",
            {"decision": "approve"},
        )

    def test_returns_first_when_multiple_subtask_interrupts_pending(self):
        sub_task = SimpleNamespace(
            interrupts=(
                SimpleNamespace(id="i-A", value="approve"),
                SimpleNamespace(id="i-B", value="reject"),
            )
        )
        state = self._snapshot(tasks=[sub_task])

        assert get_first_pending_subagent_interrupt(state) == ("i-A", "approve")

    def test_returns_none_when_no_interrupts(self):
        assert get_first_pending_subagent_interrupt(self._snapshot()) == (None, None)

    def test_returns_none_when_state_is_none(self):
        assert get_first_pending_subagent_interrupt(None) == (None, None)

    def test_skips_interrupts_with_none_value(self):
        pending = [
            SimpleNamespace(id="i-empty", value=None),
            SimpleNamespace(id="i-real", value="approve"),
        ]
        state = self._snapshot(interrupts=pending)

        assert get_first_pending_subagent_interrupt(state) == ("i-real", "approve")

    def test_normalizes_non_string_id_to_none(self):
        numeric_id = SimpleNamespace(id=12345, value="approve")
        state = self._snapshot(interrupts=[numeric_id])

        assert get_first_pending_subagent_interrupt(state) == (None, "approve")
|
@ -0,0 +1,71 @@
|
|||
"""Resume side-channel must be read exactly once per turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
|
||||
consume_surfsense_resume,
|
||||
has_surfsense_resume,
|
||||
)
|
||||
|
||||
|
||||
def _runtime_with_config(config: dict) -> ToolRuntime:
    """Wrap *config* in a bare ``ToolRuntime`` (no state, context, or store)."""
    return ToolRuntime(
        state=None,
        context=None,
        config=config,
        stream_writer=None,
        tool_call_id="tcid-test",
        store=None,
    )
||||
class TestConsumeSurfsenseResume:
    """``consume_surfsense_resume`` must pop the side-channel payload exactly once."""

    def test_pops_value_on_first_call(self):
        queued = _runtime_with_config(
            {"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
        )

        assert consume_surfsense_resume(queued) == {"decisions": ["approve"]}

    def test_second_call_returns_none(self):
        configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
        runtime = _runtime_with_config({"configurable": configurable})

        consume_surfsense_resume(runtime)  # first read pops the payload

        assert consume_surfsense_resume(runtime) is None
        assert "surfsense_resume_value" not in configurable

    def test_returns_none_when_no_payload_queued(self):
        empty_runtime = _runtime_with_config({"configurable": {}})

        assert consume_surfsense_resume(empty_runtime) is None

    def test_returns_none_when_configurable_missing(self):
        bare_runtime = _runtime_with_config({})

        assert consume_surfsense_resume(bare_runtime) is None
||||
class TestHasSurfsenseResume:
    """``has_surfsense_resume`` peeks at the side channel without consuming it."""

    def test_true_when_payload_queued(self):
        queued = _runtime_with_config(
            {"configurable": {"surfsense_resume_value": "approve"}}
        )

        assert has_surfsense_resume(queued) is True

    def test_does_not_consume_payload(self):
        configurable = {"surfsense_resume_value": "approve"}
        runtime = _runtime_with_config({"configurable": configurable})

        has_surfsense_resume(runtime)

        # Peek must leave the queued payload untouched.
        assert configurable == {"surfsense_resume_value": "approve"}

    def test_false_when_payload_absent(self):
        empty = _runtime_with_config({"configurable": {}})

        assert has_surfsense_resume(empty) is False
|
@ -0,0 +1,96 @@
|
|||
"""Subagent resilience contract: ``extra_middleware`` reaches the agent chain."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeMessagesListChatModel,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
|
||||
pack_subagent,
|
||||
)
|
||||
|
||||
|
||||
# NOTE(review): the *type name* appears to be what the scoped-fallback
# eligibility check matches on — confirm against the allowlist in the
# fallback middleware before renaming this class.
class RateLimitError(Exception):
    """Name matches the scoped-fallback eligibility allowlist."""
||||
class _AlwaysFailingChatModel(BaseChatModel):
    """Primary-model stub: every sync/async generate/stream entry point raises."""

    @property
    def _llm_type(self) -> str:
        return "always-failing-test-model"

    @staticmethod
    def _boom() -> None:
        # Single choke point so all four entry points fail identically.
        raise RateLimitError("primary llm exploded")

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        self._boom()

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        self._boom()

    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
        self._boom()

    async def _astream(
        self, *args: Any, **kwargs: Any
    ) -> AsyncIterator[ChatGeneration]:
        self._boom()
        yield  # pragma: no cover - unreachable, satisfies async generator typing
||||
@pytest.mark.asyncio
async def test_subagent_recovers_when_primary_llm_fails():
    """Fallback in ``extra_middleware`` must finish the turn when primary raises."""
    failing_primary = _AlwaysFailingChatModel()
    recovery_model = FakeMessagesListChatModel(
        responses=[AIMessage(content="recovered via fallback")]
    )

    spec = pack_subagent(
        name="resilience_test",
        description="test subagent",
        system_prompt="be helpful",
        tools=[],
        model=failing_primary,
        extra_middleware=[ModelFallbackMiddleware(recovery_model)],
    )

    # Assemble the agent from the packed spec, as the runtime would.
    agent = create_agent(
        model=spec["model"],
        tools=spec["tools"],
        middleware=spec["middleware"],
        system_prompt=spec["system_prompt"],
    )

    result = await agent.ainvoke({"messages": [HumanMessage(content="hi")]})
    last_message = result["messages"][-1]

    assert isinstance(last_message, AIMessage)
    assert last_message.content == "recovered via fallback"
|
@ -0,0 +1,128 @@
|
|||
"""``ScopedModelFallbackMiddleware`` triggers fallback only on provider errors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
|
||||
class _RaisingChatModel(BaseChatModel):
    """Model stub that raises the injected ``exc_to_raise`` from every path."""

    # Injected by tests: the exception instance every entry point re-raises.
    exc_to_raise: Any

    @property
    def _llm_type(self) -> str:
        return "raising-test-model"

    def _raise_configured(self) -> None:
        # Single choke point so sync, async, and streaming paths agree.
        raise self.exc_to_raise

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        self._raise_configured()

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        self._raise_configured()

    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
        self._raise_configured()

    async def _astream(
        self, *args: Any, **kwargs: Any
    ) -> AsyncIterator[ChatGeneration]:
        self._raise_configured()
        yield  # pragma: no cover - unreachable
||||
class _RecordingChatModel(BaseChatModel):
    """Fallback stub: replies with ``response_text`` and counts invocations."""

    # Canned reply returned by every generation.
    response_text: str = "fallback-ok"
    # Incremented on each (sync or delegated async) generate call.
    call_count: int = 0

    @property
    def _llm_type(self) -> str:
        return "recording-test-model"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        self.call_count += 1
        reply = AIMessage(content=self.response_text)
        return ChatResult(generations=[ChatGeneration(message=reply)])

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Async path delegates to the sync implementation (no run manager).
        return self._generate(messages, stop, None, **kwargs)
||||
# NOTE(review): duplicated in the e2e resilience test module — the class
# *name* is what appears to drive fallback eligibility; confirm against the
# scoped fallback middleware's allowlist.
class RateLimitError(Exception):
    """Name matches the scoped-fallback eligibility allowlist."""
||||
def _build_agent(primary: BaseChatModel, fallback: BaseChatModel):
    """Create an agent running *primary* behind a scoped fallback to *fallback*."""
    # Deferred imports: resolved only when an agent is actually built.
    from langchain.agents import create_agent

    from app.agents.new_chat.middleware.scoped_model_fallback import (
        ScopedModelFallbackMiddleware,
    )

    middleware_chain = [ScopedModelFallbackMiddleware(fallback)]
    return create_agent(
        model=primary,
        tools=[],
        middleware=middleware_chain,
        system_prompt="be helpful",
    )
||||
@pytest.mark.asyncio
async def test_provider_errors_trigger_fallback():
    """Eligible exception names must drive the fallback chain."""
    raising_primary = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider"))
    recorder = _RecordingChatModel(response_text="recovered")

    agent = _build_agent(raising_primary, recorder)
    result = await agent.ainvoke({"messages": [("user", "hi")]})
    final_message = result["messages"][-1]

    assert isinstance(final_message, AIMessage)
    assert final_message.content == "recovered"
    # Fallback model handled the turn exactly once.
    assert recorder.call_count == 1
||||
@pytest.mark.asyncio
async def test_programming_errors_propagate_without_invoking_fallback():
    """Non-eligible exceptions must propagate; fallback must not be invoked."""
    raising_primary = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field"))
    recorder = _RecordingChatModel(response_text="should-never-arrive")

    agent = _build_agent(raising_primary, recorder)

    with pytest.raises(KeyError, match="missing_state_field"):
        await agent.ainvoke({"messages": [("user", "hi")]})

    # Programming errors never reach the fallback model.
    assert recorder.call_count == 0
|
@ -202,6 +202,15 @@ class FakeBudgetLLM:
|
|||
|
||||
|
||||
class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _disable_planner_runnable(self, monkeypatch):
|
||||
# ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the
|
||||
# planner Runnable path is enabled) calls ``.bind()`` on the LLM,
|
||||
# which the mock does not implement. Pin the flag off so the
|
||||
# planner falls through to the legacy ``self.llm.ainvoke`` path
|
||||
# these tests assert against (``llm.calls[0]["config"]``).
|
||||
monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false")
|
||||
|
||||
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
|
||||
messages = [
|
||||
HumanMessage(content="old user context " * 40),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue