mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
chore: trim narrative comments and docstrings
This commit is contained in:
parent
309c695531
commit
9a4ee5d16b
6 changed files with 24 additions and 129 deletions
|
|
@ -210,10 +210,8 @@ def build_main_agent_deepagent_middleware(
|
||||||
)
|
)
|
||||||
gp_middleware.insert(_patch_idx, subagent_deny_permission_mw)
|
gp_middleware.insert(_patch_idx, subagent_deny_permission_mw)
|
||||||
|
|
||||||
# Defined here (instead of further down with the other ``wrap_model_call``
|
# Defined early so the same instances reach both gp_middleware and
|
||||||
# middlewares) so subagents share the same instances as the parent —
|
# subagent_extra_middleware below.
|
||||||
# otherwise a connector subagent would die on the first provider hiccup
|
|
||||||
# while the parent stays resilient.
|
|
||||||
retry_mw = (
|
retry_mw = (
|
||||||
RetryAfterMiddleware(max_retries=3)
|
RetryAfterMiddleware(max_retries=3)
|
||||||
if flags.enable_retry_after and not flags.disable_new_agent_stack
|
if flags.enable_retry_after and not flags.disable_new_agent_stack
|
||||||
|
|
@ -230,9 +228,7 @@ def build_main_agent_deepagent_middleware(
|
||||||
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||||
fallback_mw = None
|
fallback_mw = None
|
||||||
|
|
||||||
# Cost / loop ceiling shared with subagents. ``state_schema`` of these
|
# Per-agent caps; counts are not summed across parent + subagents.
|
||||||
# middlewares is per-agent; counts are not summed across parent + sub —
|
|
||||||
# the cap acts as a safety net per agent, not a global budget.
|
|
||||||
model_call_limit_mw = (
|
model_call_limit_mw = (
|
||||||
ModelCallLimitMiddleware(
|
ModelCallLimitMiddleware(
|
||||||
thread_limit=120,
|
thread_limit=120,
|
||||||
|
|
@ -250,9 +246,8 @@ def build_main_agent_deepagent_middleware(
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mirror the parent's ordering: retry / fallback / limits wrap caching,
|
# gp_middleware is held by reference inside general_purpose_spec, so
|
||||||
# which wraps the model. ``gp_middleware`` is held by reference inside
|
# mutating it here propagates into the spec.
|
||||||
# ``general_purpose_spec`` so this insertion propagates into the spec.
|
|
||||||
_gp_resilience: list[Any] = [
|
_gp_resilience: list[Any] = [
|
||||||
m
|
m
|
||||||
for m in (retry_mw, fallback_mw, model_call_limit_mw, tool_call_limit_mw)
|
for m in (retry_mw, fallback_mw, model_call_limit_mw, tool_call_limit_mw)
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,4 @@
|
||||||
"""Fallback only on provider/network errors; let programming bugs raise.
|
"""Fallback only on provider/network errors; let programming bugs raise."""
|
||||||
|
|
||||||
Upstream :class:`langchain.agents.middleware.ModelFallbackMiddleware` catches
|
|
||||||
every ``Exception``. With a non-provider bug (``KeyError``, ``TypeError``,
|
|
||||||
``AttributeError`` from middleware/state), every fallback model in the chain
|
|
||||||
hits the same bug — burning latency and tokens before the real cause finally
|
|
||||||
surfaces. Scoping the catch to provider-style exception types lets bugs fail
|
|
||||||
fast with clean tracebacks.
|
|
||||||
|
|
||||||
Class-name matching (instead of ``isinstance`` against imported provider
|
|
||||||
types) keeps the dependency surface flat: openai, anthropic, google,
|
|
||||||
mistral, etc. all ship their own ``RateLimitError`` and we don't want to
|
|
||||||
import them all.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
@ -26,17 +13,16 @@ if TYPE_CHECKING:
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
|
||||||
|
# Matched by class name across the MRO so we don't have to import every
|
||||||
|
# provider SDK (openai/anthropic/google/...). Extend as new providers ship.
|
||||||
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
|
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
|
||||||
{
|
{
|
||||||
# Rate / quota
|
|
||||||
"RateLimitError",
|
"RateLimitError",
|
||||||
# Server-side
|
|
||||||
"APIStatusError",
|
"APIStatusError",
|
||||||
"InternalServerError",
|
"InternalServerError",
|
||||||
"ServiceUnavailableError",
|
"ServiceUnavailableError",
|
||||||
"BadGatewayError",
|
"BadGatewayError",
|
||||||
"GatewayTimeoutError",
|
"GatewayTimeoutError",
|
||||||
# Network
|
|
||||||
"APIConnectionError",
|
"APIConnectionError",
|
||||||
"APITimeoutError",
|
"APITimeoutError",
|
||||||
"ConnectError",
|
"ConnectError",
|
||||||
|
|
@ -45,18 +31,16 @@ _FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
|
||||||
"RemoteProtocolError",
|
"RemoteProtocolError",
|
||||||
"TimeoutError",
|
"TimeoutError",
|
||||||
"TimeoutException",
|
"TimeoutException",
|
||||||
# Can be extended to other exceptions in the future
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_fallback_eligible(exc: BaseException) -> bool:
|
def _is_fallback_eligible(exc: BaseException) -> bool:
|
||||||
"""Eligible if the exception or any base in its MRO matches by class name."""
|
|
||||||
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
|
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
|
||||||
|
|
||||||
|
|
||||||
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
|
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
|
||||||
"""``ModelFallbackMiddleware`` that re-raises non-provider exceptions."""
|
"""Re-raise non-provider exceptions instead of walking the fallback chain."""
|
||||||
|
|
||||||
def wrap_model_call( # type: ignore[override]
|
def wrap_model_call( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,4 @@
|
||||||
"""End-to-end resume-bridge tests against a real LangGraph subagent.
|
"""End-to-end resume-bridge tests against a real LangGraph subagent."""
|
||||||
|
|
||||||
Builds a minimal Pregel subagent that calls ``interrupt(...)`` and drives the
|
|
||||||
``task`` tool directly with a hand-crafted ``ToolRuntime``. Exercises the only
|
|
||||||
runtime contract we own: parent stashes a decision in
|
|
||||||
``config["configurable"]["surfsense_resume_value"]`` -> bridge forwards it as
|
|
||||||
``Command(resume={interrupt_id: value})`` -> subagent completes -> return value
|
|
||||||
reflects the decision.
|
|
||||||
|
|
||||||
We pause the subagent **outside** the parent task tool (calling
|
|
||||||
``subagent.ainvoke`` directly) to skip the ``_lg_interrupt`` re-raise path,
|
|
||||||
which requires a parent runnable context. The bridge logic under test is the
|
|
||||||
*resume* dispatch, not the propagation; propagation is exercised separately in
|
|
||||||
its own module's tests.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
@ -37,8 +23,6 @@ class _SubagentState(TypedDict, total=False):
|
||||||
|
|
||||||
|
|
||||||
def _build_single_interrupt_subagent():
|
def _build_single_interrupt_subagent():
|
||||||
"""Subagent that interrupts once, then echoes the resume decision into state."""
|
|
||||||
|
|
||||||
def approve_node(state):
|
def approve_node(state):
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
|
@ -54,8 +38,6 @@ def _build_single_interrupt_subagent():
|
||||||
"review_configs": [{}],
|
"review_configs": [{}],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# Capture the resume payload verbatim so the test can assert the
|
|
||||||
# bridge forwarded it intact (no reshape, no scalar broadcast).
|
|
||||||
return {
|
return {
|
||||||
"messages": [AIMessage(content="done")],
|
"messages": [AIMessage(content="done")],
|
||||||
"decision_text": repr(decision),
|
"decision_text": repr(decision),
|
||||||
|
|
@ -81,7 +63,7 @@ def _make_runtime(config: dict) -> ToolRuntime:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
|
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
|
||||||
"""Side-channel decision -> targeted Command(resume) -> subagent completes."""
|
"""Side-channel decision must reach the subagent's pending interrupt verbatim."""
|
||||||
subagent = _build_single_interrupt_subagent()
|
subagent = _build_single_interrupt_subagent()
|
||||||
task_tool = build_task_tool_with_parent_config(
|
task_tool = build_task_tool_with_parent_config(
|
||||||
[
|
[
|
||||||
|
|
@ -93,7 +75,6 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent():
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. Pause the subagent directly so we can test only the resume path.
|
|
||||||
parent_config: dict = {
|
parent_config: dict = {
|
||||||
"configurable": {"thread_id": "shared-thread"},
|
"configurable": {"thread_id": "shared-thread"},
|
||||||
"recursion_limit": 100,
|
"recursion_limit": 100,
|
||||||
|
|
@ -104,15 +85,11 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent():
|
||||||
"fixture broken: subagent should be paused on its interrupt"
|
"fixture broken: subagent should be paused on its interrupt"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Stash the user's decision on the side-channel — this is what
|
|
||||||
# ``stream_resume_chat`` does in production.
|
|
||||||
parent_config["configurable"]["surfsense_resume_value"] = {
|
parent_config["configurable"]["surfsense_resume_value"] = {
|
||||||
"decisions": ["APPROVED"]
|
"decisions": ["APPROVED"]
|
||||||
}
|
}
|
||||||
runtime = _make_runtime(parent_config)
|
runtime = _make_runtime(parent_config)
|
||||||
|
|
||||||
# 3. Drive the bridge. Subagent has no remaining interrupt after resume,
|
|
||||||
# so propagation will not call ``_lg_interrupt`` (no parent ctx needed).
|
|
||||||
result = await task_tool.coroutine(
|
result = await task_tool.coroutine(
|
||||||
description="please approve",
|
description="please approve",
|
||||||
subagent_type="approver",
|
subagent_type="approver",
|
||||||
|
|
@ -121,27 +98,16 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent():
|
||||||
|
|
||||||
assert isinstance(result, Command)
|
assert isinstance(result, Command)
|
||||||
update = result.update
|
update = result.update
|
||||||
# Bridge forwards the side-channel payload **verbatim** to the
|
|
||||||
# subagent's ``interrupt()``. A scalar broadcast or accidental
|
|
||||||
# unwrap would change this shape and we want to catch that.
|
|
||||||
assert update["decision_text"] == repr({"decisions": ["APPROVED"]})
|
assert update["decision_text"] == repr({"decisions": ["APPROVED"]})
|
||||||
|
|
||||||
# 4. Side-channel was consumed; a stale replay would re-prompt the user.
|
|
||||||
assert "surfsense_resume_value" not in parent_config["configurable"]
|
assert "surfsense_resume_value" not in parent_config["configurable"]
|
||||||
|
|
||||||
# 5. Subagent moved past the interrupt (no pending tasks remain).
|
|
||||||
final = await subagent.aget_state(parent_config)
|
final = await subagent.aget_state(parent_config)
|
||||||
assert not final.tasks or all(not t.interrupts for t in final.tasks)
|
assert not final.tasks or all(not t.interrupts for t in final.tasks)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pending_interrupt_without_resume_value_raises_runtime_error():
|
async def test_pending_interrupt_without_resume_value_raises_runtime_error():
|
||||||
"""Bridge must fail loud if a paused subagent has no decision queued.
|
"""Bridge must fail loud rather than silently replay the user's interrupt."""
|
||||||
|
|
||||||
The fail-open alternative (silently re-invoking) would re-fire the
|
|
||||||
same interrupt to the user. The error surfaces a real broken bridge
|
|
||||||
instead of confusing duplicate approval cards.
|
|
||||||
"""
|
|
||||||
subagent = _build_single_interrupt_subagent()
|
subagent = _build_single_interrupt_subagent()
|
||||||
task_tool = build_task_tool_with_parent_config(
|
task_tool = build_task_tool_with_parent_config(
|
||||||
[
|
[
|
||||||
|
|
@ -161,7 +127,6 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error():
|
||||||
snap = await subagent.aget_state(parent_config)
|
snap = await subagent.aget_state(parent_config)
|
||||||
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
|
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
|
||||||
|
|
||||||
# No surfsense_resume_value injected — bridge must refuse to proceed.
|
|
||||||
runtime = _make_runtime(parent_config)
|
runtime = _make_runtime(parent_config)
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="resume bridge is broken"):
|
with pytest.raises(RuntimeError, match="resume bridge is broken"):
|
||||||
|
|
@ -173,8 +138,6 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error():
|
||||||
|
|
||||||
|
|
||||||
def _build_bundle_subagent():
|
def _build_bundle_subagent():
|
||||||
"""Subagent that raises a 3-action HITL bundle on its only node."""
|
|
||||||
|
|
||||||
def bundle_node(state):
|
def bundle_node(state):
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
|
@ -202,12 +165,7 @@ def _build_bundle_subagent():
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bundle_three_mixed_decisions_arrive_in_order():
|
async def test_bundle_three_mixed_decisions_arrive_in_order():
|
||||||
"""Approve / edit / reject for a 3-action bundle land at ordinals 0/1/2.
|
"""Approve / edit / reject for a 3-action bundle must land at ordinals 0/1/2."""
|
||||||
|
|
||||||
Catches reshape regressions: truncation, decision collapse, order
|
|
||||||
scrambling, and the legacy single-decision broadcast that would
|
|
||||||
fan-out one verdict to every action.
|
|
||||||
"""
|
|
||||||
subagent = _build_bundle_subagent()
|
subagent = _build_bundle_subagent()
|
||||||
task_tool = build_task_tool_with_parent_config(
|
task_tool = build_task_tool_with_parent_config(
|
||||||
[
|
[
|
||||||
|
|
@ -242,11 +200,8 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result, Command)
|
assert isinstance(result, Command)
|
||||||
decision_text = result.update["decision_text"]
|
received = ast.literal_eval(result.update["decision_text"])
|
||||||
received = ast.literal_eval(decision_text)
|
assert received == decisions_payload
|
||||||
assert received == decisions_payload, "bundle decisions must arrive verbatim"
|
|
||||||
# Cross-checks for the regressions this test exists to catch.
|
|
||||||
assert len(received["decisions"]) == 3
|
|
||||||
assert received["decisions"][0]["type"] == "approve"
|
assert received["decisions"][0]["type"] == "approve"
|
||||||
assert received["decisions"][1]["type"] == "edit"
|
assert received["decisions"][1]["type"] == "edit"
|
||||||
assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
|
assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,4 @@
|
||||||
"""Pure-function tests for the HITL resume side-channel helpers.
|
"""Resume side-channel must be read exactly once per turn."""
|
||||||
|
|
||||||
Tests the invariant that backs the bridge: a queued resume value must be
|
|
||||||
read exactly once per turn. A second read returns ``None`` so the
|
|
||||||
parent ``task`` tool falls through to its fail-loud guard rather than
|
|
||||||
replaying the same resume payload (which would re-fire the interrupt).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
@ -17,7 +11,6 @@ from app.agents.multi_agent_chat.main_agent.graph.middleware.checkpointed_subage
|
||||||
|
|
||||||
|
|
||||||
def _runtime_with_config(config: dict) -> ToolRuntime:
|
def _runtime_with_config(config: dict) -> ToolRuntime:
|
||||||
"""Real ToolRuntime; only ``.config`` is exercised by the helpers."""
|
|
||||||
return ToolRuntime(
|
return ToolRuntime(
|
||||||
state=None,
|
state=None,
|
||||||
context=None,
|
context=None,
|
||||||
|
|
@ -37,9 +30,6 @@ class TestConsumeSurfsenseResume:
|
||||||
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
|
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
|
||||||
|
|
||||||
def test_second_call_returns_none(self):
|
def test_second_call_returns_none(self):
|
||||||
# Regression guard: a second read must not replay the queued
|
|
||||||
# resume. If it did, the subagent would re-invoke with the
|
|
||||||
# same Command and the user-facing interrupt would fire twice.
|
|
||||||
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
|
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
|
||||||
runtime = _runtime_with_config({"configurable": configurable})
|
runtime = _runtime_with_config({"configurable": configurable})
|
||||||
|
|
||||||
|
|
@ -68,9 +58,6 @@ class TestHasSurfsenseResume:
|
||||||
assert has_surfsense_resume(runtime) is True
|
assert has_surfsense_resume(runtime) is True
|
||||||
|
|
||||||
def test_does_not_consume_payload(self):
|
def test_does_not_consume_payload(self):
|
||||||
# The fail-loud guard in ``task_tool`` calls ``has_surfsense_resume``
|
|
||||||
# *before* deciding to consume; the check itself must leave the
|
|
||||||
# payload queued for the matching ``consume_surfsense_resume`` call.
|
|
||||||
configurable = {"surfsense_resume_value": "approve"}
|
configurable = {"surfsense_resume_value": "approve"}
|
||||||
runtime = _runtime_with_config({"configurable": configurable})
|
runtime = _runtime_with_config({"configurable": configurable})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,4 @@
|
||||||
"""Resilience contract for subagents built via ``pack_subagent``.
|
"""Subagent resilience contract: ``extra_middleware`` reaches the agent chain."""
|
||||||
|
|
||||||
Subagents (jira, linear, notion, ...) run on the same LLM as the parent. When
|
|
||||||
the provider rate-limits or returns an empty stream, a single hiccup must not
|
|
||||||
abort the user's HITL flow — the connector subagent has to keep moving. This
|
|
||||||
relies on ``ModelFallbackMiddleware`` being usable as a subagent
|
|
||||||
``extra_middleware`` so the production builder can wire it in.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
@ -32,11 +25,10 @@ from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
|
||||||
|
|
||||||
|
|
||||||
class RateLimitError(Exception):
|
class RateLimitError(Exception):
|
||||||
"""Provider-style 429; matches the scoped-fallback eligibility allowlist by name."""
|
"""Name matches the scoped-fallback eligibility allowlist."""
|
||||||
|
|
||||||
|
|
||||||
class _AlwaysFailingChatModel(BaseChatModel):
|
class _AlwaysFailingChatModel(BaseChatModel):
|
||||||
"""Mimics a provider hard-failing on every call (rate limit / empty stream)."""
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
|
|
@ -76,7 +68,7 @@ class _AlwaysFailingChatModel(BaseChatModel):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_subagent_recovers_when_primary_llm_fails():
|
async def test_subagent_recovers_when_primary_llm_fails():
|
||||||
"""Primary blows up → fallback in extra_middleware finishes the turn."""
|
"""Fallback in ``extra_middleware`` must finish the turn when primary raises."""
|
||||||
primary = _AlwaysFailingChatModel()
|
primary = _AlwaysFailingChatModel()
|
||||||
fallback = FakeMessagesListChatModel(
|
fallback = FakeMessagesListChatModel(
|
||||||
responses=[AIMessage(content="recovered via fallback")]
|
responses=[AIMessage(content="recovered via fallback")]
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,4 @@
|
||||||
"""Exception-scope contract for ``ScopedModelFallbackMiddleware``.
|
"""``ScopedModelFallbackMiddleware`` triggers fallback only on provider errors."""
|
||||||
|
|
||||||
Upstream ``ModelFallbackMiddleware`` catches every ``Exception`` and walks
|
|
||||||
the fallback chain. That means a programming bug (``KeyError`` from a
|
|
||||||
botched tool config, ``TypeError`` from middleware, ...) burns 1+N model
|
|
||||||
round-trips and ~Nx tokens before its real cause surfaces. The scoped
|
|
||||||
variant only falls back on provider/network exception types so bugs fail
|
|
||||||
fast, with clean tracebacks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
@ -24,8 +16,6 @@ from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
|
|
||||||
|
|
||||||
class _RaisingChatModel(BaseChatModel):
|
class _RaisingChatModel(BaseChatModel):
|
||||||
"""LLM that raises a configurable exception on every invocation."""
|
|
||||||
|
|
||||||
exc_to_raise: Any
|
exc_to_raise: Any
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -61,8 +51,6 @@ class _RaisingChatModel(BaseChatModel):
|
||||||
|
|
||||||
|
|
||||||
class _RecordingChatModel(BaseChatModel):
|
class _RecordingChatModel(BaseChatModel):
|
||||||
"""Returns a fixed message and counts how often it was called."""
|
|
||||||
|
|
||||||
response_text: str = "fallback-ok"
|
response_text: str = "fallback-ok"
|
||||||
call_count: int = 0
|
call_count: int = 0
|
||||||
|
|
||||||
|
|
@ -94,14 +82,11 @@ class _RecordingChatModel(BaseChatModel):
|
||||||
return self._generate(messages, stop, None, **kwargs)
|
return self._generate(messages, stop, None, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# Locally defined provider-style error: importing openai/anthropic/etc.
|
|
||||||
# would couple the test to provider SDKs the contract intentionally avoids.
|
|
||||||
class RateLimitError(Exception):
|
class RateLimitError(Exception):
|
||||||
"""Mimics ``openai.RateLimitError`` for name-based eligibility."""
|
"""Name matches the scoped-fallback eligibility allowlist."""
|
||||||
|
|
||||||
|
|
||||||
def _build_agent(primary: BaseChatModel, fallback: BaseChatModel):
|
def _build_agent(primary: BaseChatModel, fallback: BaseChatModel):
|
||||||
"""Compile a no-tools agent with the scoped fallback wired in."""
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
|
|
||||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||||
|
|
@ -118,7 +103,7 @@ def _build_agent(primary: BaseChatModel, fallback: BaseChatModel):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_provider_errors_trigger_fallback():
|
async def test_provider_errors_trigger_fallback():
|
||||||
"""Class names matching the provider allowlist drive the fallback chain."""
|
"""Eligible exception names must drive the fallback chain."""
|
||||||
primary = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider"))
|
primary = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider"))
|
||||||
fallback = _RecordingChatModel(response_text="recovered")
|
fallback = _RecordingChatModel(response_text="recovered")
|
||||||
|
|
||||||
|
|
@ -133,7 +118,7 @@ async def test_provider_errors_trigger_fallback():
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_programming_errors_propagate_without_invoking_fallback():
|
async def test_programming_errors_propagate_without_invoking_fallback():
|
||||||
"""``KeyError`` from agent-side bugs must surface immediately, no fallback retry."""
|
"""Non-eligible exceptions must propagate; fallback must not be invoked."""
|
||||||
primary = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field"))
|
primary = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field"))
|
||||||
fallback = _RecordingChatModel(response_text="should-never-arrive")
|
fallback = _RecordingChatModel(response_text="should-never-arrive")
|
||||||
|
|
||||||
|
|
@ -142,7 +127,4 @@ async def test_programming_errors_propagate_without_invoking_fallback():
|
||||||
with pytest.raises(KeyError, match="missing_state_field"):
|
with pytest.raises(KeyError, match="missing_state_field"):
|
||||||
await agent.ainvoke({"messages": [("user", "hi")]})
|
await agent.ainvoke({"messages": [("user", "hi")]})
|
||||||
|
|
||||||
assert fallback.call_count == 0, (
|
assert fallback.call_count == 0
|
||||||
"fallback was invoked for a programming error; "
|
|
||||||
"scoping rule is broken"
|
|
||||||
)
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue